[RFC] Handling Effect in TVM and Relay

I took a look at Jax’s approach; I just want to summarize and share it below:

  1. A random number function requires an explicit random state (a “key”) as one of its arguments. Calling the random number function does not change the key/random state:

nums = random.normal(key, shape=(1,))

  2. To get a different random state, users need to call the split function explicitly:
_, subkey = random.split(key); 
different_nums = random.normal(subkey, shape=(1,))
  3. In a training program, the keys are split and replaced outside the JIT program, in Python. The JIT program just takes all the new keys as inputs, and they are immutable.

Since random.normal does not change the state, it implies that most off-the-shelf random number generators cannot be plugged in. Jax designed its own PRNG module and implementation.
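
For context, here is a minimal sketch of this pattern in plain Python using the public jax.random API (the function noisy_add is just an illustrative name): the jitted function only consumes keys, while splitting happens outside of it.

import jax

@jax.jit
def noisy_add(x, key):
    # the key is an ordinary immutable input; drawing a sample does not mutate it
    return x + jax.random.normal(key, shape=x.shape)

key = jax.random.PRNGKey(0)
x = jax.numpy.ones((3,))
for _ in range(2):
    key, subkey = jax.random.split(key)  # the new state is produced explicitly, in Python
    x = noisy_add(x, subkey)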

Just to follow up on what @tqchen summarized previously, here’s my understanding:

frontend converters

We want users who write frontend converters to be aware that certain operators are stateful. We can encourage them to write these operators in A1 style. For instance:

def _mx_dropout_train(inputs, attrs, module):
    rate = attrs.get_float("p", 0.5)
    global_state = module['prng']
    state_ref = relay.RefCreate(global_state)
    read_state = relay.RefRead(state_ref)
    # the dropout_train operator outputs both y and the new state 
    y_state = _op.nn.dropout_train(inputs[0], read_state, rate=rate)
    # write back new state, return y 
    write_state = relay.RefWrite(state_ref, y_state[1])
    y = relay.Let(relay.var('ref_write'), write_state, y_state[0])
    return y

where module['prng'] is a global variable representing the PRNG state in the module. As of now, global variables are only used to represent functions; we would need to extend them to represent the random state as well.

rewriting A1-style programs to A2-style ones

Let’s say we have a function below with stateful ops:

def @func1(%x) {
  %0 = ref(@prng_state);
  %1 = %0^;
  %2 = nn.dropout_train(%x, %1, rate=0.7f)
  %3 = %2.1;
  let %ref_write: () = (%0 := %3);
  %2.0
}

In the rewriting pass, we detect that the global random state is used and rewrite the function’s references to it, producing the following:

def @func1_rewritten(%x, %state) {
  %2 = nn.dropout_train(%x, %state, rate=0.7f)
  (%2.0, %2.1)
}

Note that the function’s output type is changed to a tuple containing the new state. Meanwhile, we need to update all CallNodes for this function accordingly (a sketch of the call-site update follows the next example). Here is another example:

def @long_func(%x) {
  %0 = ref(@prng_state);
  %1 = %0^;
  %2 = nn.dropout_train(%x, %1, rate=0.7f)
  %3 = %2.1;
  %4 = (
    let %ref_write1: () = (%0 := %3);
    %2.0
  );
  %5 = %0^;
  %6 = nn.dropout_train(%4, %5, rate=0.1f) 
  %7 = %6.1;
  let %ref_write: () = (%0 := %7);
  %6.0
}

===> 

def @long_func_rewritten(%x, %state) {
  %2 = nn.dropout_train(%x, %state, rate=0.7f)
  %3 = %2.1;
  %4 = %2.0;
  %6 = nn.dropout_train(%4, %3, rate=0.1f) 
  (%6.1, %6.0)
}
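
To illustrate the CallNode update mentioned above, here is a rough sketch using the Relay Python API (the helper and its calling convention are assumptions for illustration, not an existing TVM utility). Each rewritten call site passes the currently live state as an extra argument and unpacks the returned tuple; the new state is then threaded into the next stateful call site.

from tvm import relay

def rewrite_call_site(func_rewritten_gv, args, state_expr):
    # call the rewritten function with the current state appended as an extra argument
    call = relay.Call(func_rewritten_gv, list(args) + [state_expr])
    value = relay.TupleGetItem(call, 0)      # the original return value
    new_state = relay.TupleGetItem(call, 1)  # the threaded state for later call sites
    return value, new_state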

Note that the pass implementation requires tracking the latest value of the global variable within each scope. For instance, the program below:

def @func2(%x, %y) { # returns tensor
  if (%x) {
    add(%x, %y)
  } else {
    func1(%y)
  }
}

would be rewritten to:

def @func2(%x, %y, %state) {  # returns (tensor, state) for both branches
  if (%x) {
    (add(%x, %y), %state) # the original state is also returned
  } else {
    func1_rewritten(%y, %state) # returns the new state
  }
}

Since the pass requires evaluation within each scope, it would be easier to implement it after the program has already been transformed to basic-block normal form.
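
As a rough illustration of that bookkeeping (plain Python with hypothetical names, not tied to any existing TVM pass infrastructure): the pass can keep a stack of per-scope "current state" expressions, where a RefWrite updates the top of the stack and an if-branch starts from the enclosing state and must return its own final state.

class StateScope:
    # Hypothetical bookkeeping for tracking the latest state value per scope.
    def __init__(self, init_state):
        self.stack = [init_state]

    def current(self):
        # the latest state expression visible in the current scope
        return self.stack[-1]

    def update(self, new_state):
        # a RefWrite in the original program becomes an update of the tracked value
        self.stack[-1] = new_state

    def enter_branch(self):
        # both branches of an `if` start from the enclosing scope's state
        self.stack.append(self.current())

    def exit_branch(self):
        # on leaving the branch, its final state must be returned explicitly,
        # e.g. as the second element of the branch's result tuple
        return self.stack.pop()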

discussions

what type do we use for the random state?

  • option a: use the empty tuple type. The runtime actually uses the global state, and it relies on the deterministic execution order of the program to ensure reproducibility.
  • option b: add a new type (e.g. TypeRandState), and the random state Object actually carries the data structure used for generating random numbers (e.g. std::mt19937). The state is passed around in the program, and invoking an operator with the same state object always leads to the same deterministic outputs.

@junrushao @haichen @MarisaKirisame @ziheng would you like to provide some suggestions/comments?

Hey @eric-haibin-lin, thank you for the valuable examples!

I was wondering whether we could further alleviate frontend developers’ burden in writing operators with side effects. For example, frontend developers would only be required to produce a program like the one below:

fn @just_a_dropout(%x) {
  let %y = _stateful_dropout(%x);
  %y
}

and then we provide a pass to replace _stateful_dropout properly:

fn @just_a_dropout(%x) {
  let %prng = RefRead(%global_var_prng_ref);
  let (%y, %new_prng) = dropout(%x, %prng);
  RefWrite(%global_var_prng_ref, %new_prng);
  %y
}

Note that this approach requires that we ignore potential dependency issues. For example, if there are two parallel dropouts with no dependency between them, this approach would impose an arbitrary ordering on them. However, in our case of neural networks, it doesn’t really matter.
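
For concreteness, here is a rough sketch of what such a pass could look like with ExprMutator. The placeholder name _stateful_dropout, the pure operator that takes (x, state) and returns (y, new_state), and the way the global state ref is obtained are all assumptions for illustration; none of this exists in mainline Relay today.

from tvm import relay
from tvm.relay.expr_functor import ExprMutator

class LowerStatefulDropout(ExprMutator):
    """Hypothetical sketch: replace calls named "_stateful_dropout" with an
    explicit RefRead / pure dropout / RefWrite sequence on a global PRNG ref."""

    def __init__(self, prng_ref, pure_dropout_op):
        super().__init__()
        self.prng_ref = prng_ref                # expression evaluating to the global state ref
        self.pure_dropout_op = pure_dropout_op  # assumed pure op: (x, state) -> (y, new_state)

    def visit_call(self, call):
        new_call = super().visit_call(call)
        if getattr(new_call.op, "name", None) != "_stateful_dropout":
            return new_call
        state = relay.RefRead(self.prng_ref)
        out = relay.Call(self.pure_dropout_op, [new_call.args[0], state])
        y = relay.TupleGetItem(out, 0)
        new_state = relay.TupleGetItem(out, 1)
        # write back the new state, then return y
        return relay.Let(relay.var("_"), relay.RefWrite(self.prng_ref, new_state), y)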

I don’t have problems with rewriting to state-passing style, and I agree with the potential typing issues for the PRNG key. As @eric-haibin-lin mentioned, it can be either a unit type (the empty tuple) or a new type introduced to Relay. Would love to hear more from @tqchen, @MarisaKirisame and @jroesch.

Hey everyone, reviving this thread as @tkonolige, @jroesch, @antinucleon and I have been experimenting with adding some PRNG support to Relay.

While nothing is finalized, we are currently trying the Jax approach of explicit user-side PRNG key management. The reasoning is as follows: in general, most networks use the PRNG quite simply, so the user defining the network:

  1. in Relay, can easily pass around a key while they build the network
  2. in importers, can have a mutable field storing the key (in Python) and replace the key by splitting after each use.

Here is an example taken from Relay’s VGG definition and modified as suggested:

def get_classifier(input_data, num_classes, prng_key):
    """Get VGG classifier layers as fc layers."""
    left, right = relay.random.split(prng_key)
    flatten = relay.nn.batch_flatten(data=input_data)
    fc6 = wrapper.dense_add_bias(data=flatten, units=4096, name="fc6")
    relu6 = relay.nn.relu(data=fc6)
    drop6 = relay.nn.dropout(data=relu6, rate=0.5, key=left)
    fc7 = wrapper.dense_add_bias(data=drop6, units=4096, name="fc7")
    relu7 = relay.nn.relu(data=fc7)
    drop7 = relay.nn.dropout(data=relu7, rate=0.5, key=right)
    fc8 = wrapper.dense_add_bias(data=drop7, units=num_classes, name="fc8")
    return fc8

Note that we must only use a PRNG key once (i.e., treat it as a linear resource), but here we don’t need the original key again, so we can use both results from random.split. Then, to use this in defining the full network, we can simply write (for example)

classifier = get_classifier(feature, num_classes, relay.random.key(0))
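
As a small illustration of the “use a key only once” rule, reusing the hypothetical key-based API from the example above: passing the same key to two dropouts would make them draw identical randomness, so each consumer should get its own key from a split.

# Anti-pattern: both dropouts see the same key, so they would generate identical masks.
bad6 = relay.nn.dropout(relu6, rate=0.5, key=prng_key)
bad7 = relay.nn.dropout(relu7, rate=0.5, key=prng_key)

# Correct: split first, then give each consumer its own key (as in get_classifier above).
left, right = relay.random.split(prng_key)
drop6 = relay.nn.dropout(relu6, rate=0.5, key=left)
drop7 = relay.nn.dropout(relu7, rate=0.5, key=right)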

For an example with an importer (this is Caffe):

    def convert_dropout(self, op):
        """ Convert Dropout layer """
        inputs = op.bottom
        input_name = inputs[0]
        next_key, dropout_key = _op.random.split(self.prng_key)
        self.prng_key = next_key  # we can wrap this in some helper method

        params = dict()
        dropout_params = op.dropout_param

        params["rate"] = dropout_params.dropout_ratio

        in_expr = self.exp_tab.get_expr(input_name)
        out = _op.nn.dropout(in_expr, **params, key=dropout_key)
        return out

Would love to hear your thoughts and if I missed any kind of edge case here! I’m also happy to try and write some pseudocode examples for more complicated use cases if anyone is interested.

EDIT: I forgot to mention that, as @eric-haibin-lin pointed out, this means we cannot plug in off-the-shelf PRNG kernels. However, we have written a splittable PRNG kernel in TIR (Threefry), so some of the groundwork is done. This will also let us run on GPU.
