[RFC] Handling Effect in TVM and Relay

Hey everyone, I'm reviving this thread because @tkonolige, @jroesch, @antinucleon and I have been experimenting with adding PRNG support to Relay.

Nothing is finalized yet, but we are currently trying JAX's approach of explicit, user-side PRNG key management. The reasoning is that most networks use randomness in simple ways, so a user defining a network

  1. in Relay, can easily pass a key around while building the network;
  2. in an importer, can store the key in a mutable field (in Python) and replace it after each use by splitting it.

Here is an example taken from Relay's VGG definition, modified as suggested:

```python
from tvm import relay
from tvm.relay.testing import layers as wrapper  # provides dense_add_bias


def get_classifier(input_data, num_classes, prng_key):
    """Get VGG classifier layers as fc layers."""
    # Split the incoming key once; each half is consumed by exactly one dropout.
    left, right = relay.random.split(prng_key)
    flatten = relay.nn.batch_flatten(data=input_data)
    fc6 = wrapper.dense_add_bias(data=flatten, units=4096, name="fc6")
    relu6 = relay.nn.relu(data=fc6)
    # Proposed API: dropout takes an explicit PRNG key.
    drop6 = relay.nn.dropout(data=relu6, rate=0.5, key=left)
    fc7 = wrapper.dense_add_bias(data=drop6, units=4096, name="fc7")
    relu7 = relay.nn.relu(data=fc7)
    drop7 = relay.nn.dropout(data=relu7, rate=0.5, key=right)
    fc8 = wrapper.dense_add_bias(data=drop7, units=num_classes, name="fc8")
    return fc8
```

Note that each PRNG key must be used only once (i.e. treated as a linear resource). Here we don't need the incoming key again afterwards, so we can consume both results of random.split. To use this when defining the full network, we can then write, for example:

```python
classifier = get_classifier(feature, num_classes, relay.random.key(0))
```
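To make the use-once rule concrete, here is a minimal pure-Python sketch in which a key records whether it has been consumed and raises on reuse. `LinearKey`, `key`, `split`, `dropout_mask`, and the BLAKE2b-based derivation are hypothetical stand-ins chosen for illustration; they are not the proposed Relay API and not Threefry, and the sketch says nothing about how Relay itself would enforce the rule.

```python
import hashlib


class KeyReuseError(RuntimeError):
    """Raised when a PRNG key is consumed more than once."""


class LinearKey:
    """Toy PRNG key that may be consumed exactly once (a linear resource)."""

    def __init__(self, data: bytes):
        self._data = data
        self._used = False

    def consume(self) -> bytes:
        # Every operation that uses the key (split, dropout, ...) goes through here.
        if self._used:
            raise KeyReuseError("PRNG key was already used")
        self._used = True
        return self._data


def key(seed: int) -> LinearKey:
    # Hypothetical seed-to-key constructor, similar in spirit to relay.random.key(0).
    return LinearKey(hashlib.blake2b(seed.to_bytes(8, "little"), digest_size=16).digest())


def split(parent: LinearKey) -> tuple:
    # Consume the parent once and derive two child keys by hashing (parent, index).
    data = parent.consume()
    return tuple(
        LinearKey(hashlib.blake2b(data + bytes([i]), digest_size=16).digest())
        for i in range(2)
    )


def dropout_mask(k: LinearKey, n: int, rate: float) -> list:
    # Consume the key and derive n keep (True) / drop (False) decisions from it.
    data = k.consume()
    mask = []
    for i in range(n):
        h = hashlib.blake2b(data + i.to_bytes(4, "little"), digest_size=8).digest()
        u = int.from_bytes(h, "little") / 2**64  # uniform in [0, 1)
        mask.append(u >= rate)
    return mask


left, right = split(key(0))
m1 = dropout_mask(left, 8, 0.5)
m2 = dropout_mask(right, 8, 0.5)
try:
    dropout_mask(left, 8, 0.5)  # second use of `left`
    rejected = False
except KeyReuseError:
    rejected = True  # reuse is caught instead of silently repeating randomness
```

Because every key is derived deterministically from the seed, rebuilding the same graph from `key(0)` yields the same masks, while an accidental second use of `left` fails loudly instead of producing correlated randomness.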

Here is an example with an importer (Caffe):

```python
# Excerpt: a method of the Caffe frontend's converter class.
def convert_dropout(self, op):
    """Convert Dropout layer"""
    inputs = op.bottom
    input_name = inputs[0]
    # self.prng_key holds the importer's current key: split it, keep one half
    # for later ops, and give the other half to this dropout.
    next_key, dropout_key = _op.random.split(self.prng_key)
    self.prng_key = next_key  # we can wrap this in some helper method

    params = dict()
    dropout_params = op.dropout_param

    params["rate"] = dropout_params.dropout_ratio

    in_expr = self.exp_tab.get_expr(input_name)
    out = _op.nn.dropout(in_expr, **params, key=dropout_key)
    return out
```
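The split-and-replace step in `convert_dropout` can be factored into a single helper. Below is a self-contained sketch of that pattern; `KeyedConverter`, `_next_key`, and the SHA-256-based `split_key` are hypothetical stand-ins (in the real frontend the split would be the proposed `_op.random.split`, and keys would be Relay expressions rather than bytes).

```python
import hashlib


def split_key(key: bytes) -> tuple:
    # Toy stand-in for a splittable PRNG's split: derive two children by hashing.
    return tuple(hashlib.sha256(key + bytes([i])).digest()[:16] for i in range(2))


class KeyedConverter:
    """Hypothetical, simplified importer that owns a mutable PRNG key."""

    def __init__(self, seed: int = 0):
        self.prng_key = hashlib.sha256(seed.to_bytes(8, "little")).digest()[:16]

    def _next_key(self) -> bytes:
        # Split the stored key: keep one half for future ops, hand out the other.
        self.prng_key, op_key = split_key(self.prng_key)
        return op_key

    def convert_dropout(self, layer_name: str) -> dict:
        # Returns a plain record instead of a Relay expr; each op gets a fresh key.
        return {"op": "dropout", "layer": layer_name, "key": self._next_key()}


conv = KeyedConverter(seed=0)
ops = [conv.convert_dropout(name) for name in ("drop6", "drop7")]
```

One consequence of this design: the key each op receives depends on the order in which ops are converted, so reproducible randomness across imports relies on the importer visiting layers in a deterministic order.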

I'd love to hear your thoughts, and whether I've missed any edge cases! I'm also happy to write some pseudocode examples for more complicated use cases if anyone is interested.

EDIT: I forgot to mention that, as @eric-haibin-lin pointed out, this approach means we cannot plug in off-the-shelf PRNG kernels. However, we have already written a splittable PRNG kernel (Threefry) in TIR, so some of the groundwork is done. This will also let us run on GPU.
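For readers unfamiliar with Threefry, the Threefry-2x32 block function (20 rounds, the variant used by Random123 and JAX) is small enough to sketch in pure Python. This is not the TIR kernel mentioned above; it only illustrates why a counter-based generator fits this design: each output is a pure function of (key, counter), so there is no hidden state to thread through the graph, and different counters can be evaluated independently (e.g. one GPU thread each). The `split` and `random_bits` helpers are hypothetical and are not bit-compatible with JAX's or TVM's versions.

```python
MASK32 = 0xFFFFFFFF
# Threefry-2x32 rotation constants; rounds alternate between the two halves.
ROTATIONS = (13, 15, 26, 6, 17, 29, 16, 24)
SKEIN_PARITY = 0x1BD11BDA


def _rotl32(x: int, r: int) -> int:
    return ((x << r) | (x >> (32 - r))) & MASK32


def threefry2x32(key: tuple, ctr: tuple) -> tuple:
    """Encrypt a 2x32-bit counter under a 2x32-bit key (20 rounds)."""
    ks = (key[0], key[1], SKEIN_PARITY ^ key[0] ^ key[1])
    x0 = (ctr[0] + ks[0]) & MASK32
    x1 = (ctr[1] + ks[1]) & MASK32
    for block in range(5):  # 5 blocks x 4 rounds = 20 rounds
        rots = ROTATIONS[:4] if block % 2 == 0 else ROTATIONS[4:]
        for r in rots:
            x0 = (x0 + x1) & MASK32
            x1 = _rotl32(x1, r) ^ x0
        # Key injection after every 4 rounds.
        s = block + 1
        x0 = (x0 + ks[s % 3]) & MASK32
        x1 = (x1 + ks[(s + 1) % 3] + s) & MASK32
    return x0, x1


def split(key: tuple) -> tuple:
    # Hypothetical split: child i is the encryption of counter (i, 0) under the parent.
    return threefry2x32(key, (0, 0)), threefry2x32(key, (1, 0))


def random_bits(key: tuple, n: int) -> list:
    # Each block of two words depends only on (key, i), so blocks are independent.
    # Under the use-once rule a key is either split or passed here, never both,
    # so sharing counters with split() is harmless.
    words = []
    for i in range((n + 1) // 2):
        words.extend(threefry2x32(key, (i, 0)))
    return words[:n]


left, right = split((0, 42))
bits = random_bits(left, 5)
```

Because the block function is a permutation of the counter for a fixed key, the two children produced by `split` are guaranteed to differ.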
