Just to follow up on what @tqchen summarized previously, here’s my understanding:
frontend converters
We want users who write frontend converters to be aware that certain operators are stateful, and we can encourage them to write such operators in the A1 style. For instance:
```python
def _mx_dropout_train(inputs, attrs, module):
    rate = attrs.get_float("p", 0.5)
    # module['prng'] is the global variable holding the PRNG state
    global_state = module['prng']
    state_ref = relay.RefCreate(global_state)
    read_state = relay.RefRead(state_ref)
    # the dropout_train operator outputs both y and the new state
    y_state = _op.nn.dropout_train(inputs[0], read_state, rate=rate)
    # write the new state back into the ref, then return y
    write_state = relay.RefWrite(state_ref, relay.TupleGetItem(y_state, 1))
    y = relay.Let(relay.var('ref_write'), write_state, relay.TupleGetItem(y_state, 0))
    return y
```
where module['prng'] is a global variable representing the PRNG state in the module. Currently, global variables are only used to represent functions; we need to extend them to represent the random state as well.
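As a quick illustration of the current behavior (using only existing APIs), every global variable in an IRModule today is bound to a relay.Function; binding a global such as "prng" to a random-state value is exactly the extension being proposed:

```python
import tvm
from tvm import relay

# Today, every global variable in an IRModule is bound to a function.
x = relay.var("x", shape=(2, 2))
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))
print(mod.get_global_vars())  # e.g. [GlobalVar(main)]

# The proposal is to also allow a global like "prng" whose binding is a PRNG
# state value rather than a relay.Function; that is what `module['prng']`
# in the converter sketch above relies on.
```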
rewriting A1-style programs to A2-style ones
Let’s say we have a function below with stateful ops:
```
def @func1(%x) {
  %0 = ref(@prng_state);
  %1 = %0^;
  %2 = nn.dropout_train(%x, %1, rate=0.7f);
  %3 = %2.1;
  let %ref_write: () = (%0 := %3);
  %2.0
}
```
In the rewriting pass, we detect that the global random state is used and rewrite the function as follows:
```
def @func1_rewritten(%x, %state) {
  %2 = nn.dropout_train(%x, %state, rate=0.7f);
  (%2.0, %2.1)
}
```
Note that the function's output type is changed to a tuple that also carries the new state, so we need to update all CallNodes of this function accordingly. Here is another example:
```
def @long_func(%x) {
  %0 = ref(@prng_state);
  %1 = %0^;
  %2 = nn.dropout_train(%x, %1, rate=0.7f);
  %3 = %2.1;
  %4 = (
    let %ref_write1: () = (%0 := %3);
    %2.0
  );
  %5 = %0^;
  %6 = nn.dropout_train(%4, %5, rate=0.1f);
  %7 = %6.1;
  let %ref_write: () = (%0 := %7);
  %6.0
}
```
===>
```
def @long_func_rewritten(%x, %state) {
  %2 = nn.dropout_train(%x, %state, rate=0.7f);
  %3 = %2.1;
  %4 = %2.0;
  %6 = nn.dropout_train(%4, %3, rate=0.1f);
  (%6.0, %6.1)
}
```
Note that the pass implementation requires tracking the latest value of the global variable within each scope. For instance, the program below:
```
def @func2(%x, %y) {  // returns a tensor
  if (%x) {
    add(%x, %y)
  } else {
    @func1(%y)
  }
}
```
would be rewritten to:
```
def @func2(%x, %y, %state) {  // returns (tensor, state) on both branches
  if (%x) {
    (add(%x, %y), %state)  // the original state is passed through unchanged
  } else {
    @func1_rewritten(%y, %state)  // returns the new state
  }
}
```
Since the pass has to track the current state value within each scope, it would be easier to implement after the program has already been transformed into basic block normal form.
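To make this concrete, here is a rough Python sketch of such a pass built on relay.ExprMutator. It is only a simplified outline under several assumptions: the program is already in basic block normal form, the only reference cells in the function belong to the PRNG state, and the handling of if branches and call-site updates discussed above is omitted; nn.dropout_train itself is the proposed operator, not an existing one.

```python
from tvm import relay


class ThreadPRNGState(relay.ExprMutator):
    """Rewrite A1-style state handling (RefRead/RefWrite on the PRNG ref) into
    A2 style (an explicit state parameter that is threaded through and returned).

    Simplifying assumptions: basic block normal form, the only Ref in the
    function is the PRNG state ref, and `if` branches / call sites are not
    handled here.
    """

    def __init__(self):
        super().__init__()
        self.state_var = relay.var("state")  # the new explicit state parameter
        self.cur_state = self.state_var      # latest state value in the current scope
        self.uses_state = False

    def visit_ref_read(self, r):
        # `%0^` becomes whatever the latest state value is.
        self.uses_state = True
        return self.cur_state

    def visit_ref_write(self, r):
        # `%0 := %new` only updates the tracked state; the write expression
        # itself evaluates to the empty tuple, matching RefWrite's result type.
        self.cur_state = self.visit(r.value)
        return relay.Tuple([])

    def visit_function(self, fn):
        body = self.visit(fn.body)
        if not self.uses_state:
            return fn
        # Append the state parameter and return (original result, final state).
        # The now-dead `ref(@prng_state)` binding is left for a later
        # dead-code-elimination pass to remove.
        return relay.Function(list(fn.params) + [self.state_var],
                              relay.Tuple([body, self.cur_state]))
```

Applying ThreadPRNGState().visit(...) to @func1 above should produce something close to @func1_rewritten, modulo the dead ref binding.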
discussions
what type do we use for the random state?
- option a: use the empty tuple type. The runtime actually uses the global state, and it relies on the deterministic execution order of the program to ensure reproducibility.
- option b: add a new type (e.g. TypeRandState), and the random state Object actually carries the data structure used for generating random numbers (e.g. std::mt19937). The state is passed around in the program, and invoking an operator with the same state object always leads to the same deterministic outputs (see the Python analogy after this list).
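To make option (b)'s semantics concrete, here is a small pure-Python/NumPy analogy (not TVM code): the state is an explicit value, so invoking the op twice with the same state gives identical results. The function name and the way a new state is derived are illustrative assumptions only.

```python
import numpy as np

def dropout_train(x, state, rate):
    """Functional dropout analogy: returns (output, new_state) and never
    mutates `state`; the same state always produces the same mask."""
    rng = np.random.default_rng(state)           # the state fully determines the draw
    mask = rng.random(x.shape) >= rate
    new_state = int(rng.integers(0, 2**63 - 1))  # derive a fresh state to return
    return np.where(mask, x / (1.0 - rate), 0.0), new_state

x = np.ones((2, 3))
y1, s1 = dropout_train(x, state=42, rate=0.5)
y2, s2 = dropout_train(x, state=42, rate=0.5)
assert (y1 == y2).all() and s1 == s2   # same state -> same deterministic result
```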
@junrushao @haichen @MarisaKirisame @ziheng would you like to provide some suggestions/comments?