Hey everyone, reviving this thread as @tkonolige, @jroesch, @antinucleon and I have been experimenting with adding some PRNG support to Relay.
While nothing is finalized, we are currently trying the JAX approach of explicit user-side PRNG key management. The reasoning is as follows: most networks use PRNGs quite simply, so the user defining the network
- in Relay, can easily pass around a key while they build the network
- in importers, can have a mutable field storing the key (in Python) and replace the key by splitting after each use.
Here's an example taken from Relay's VGG model, modified as suggested:
```python
def get_classifier(input_data, num_classes, prng_key):
    """Get VGG classifier layers as fc layers."""
    left, right = relay.random.split(prng_key)
    flatten = relay.nn.batch_flatten(data=input_data)
    fc6 = wrapper.dense_add_bias(data=flatten, units=4096, name="fc6")
    relu6 = relay.nn.relu(data=fc6)
    drop6 = relay.nn.dropout(data=relu6, rate=0.5, key=left)
    fc7 = wrapper.dense_add_bias(data=drop6, units=4096, name="fc7")
    relu7 = relay.nn.relu(data=fc7)
    drop7 = relay.nn.dropout(data=relu7, rate=0.5, key=right)
    fc8 = wrapper.dense_add_bias(data=drop7, units=num_classes, name="fc8")
    return fc8
```
Note that a PRNG key must be used only once (treat it as a linear resource); here we don't need the key again, so we can use both results from `random.split`. Then, to use this in defining the full network, we can simply write (for example):
```python
classifier = get_classifier(feature, num_classes, relay.random.key(0))
```
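For networks with more than two random layers, the pairwise split chains naturally. Here's a plain-Python sketch of that discipline; `split` below is a toy arithmetic stand-in (not the proposed Relay op, which would return Relay expressions), and `split_n` is a hypothetical helper, not part of the proposal:

```python
def split(key):
    """Toy stand-in for relay.random.split: derive two fresh keys from one.

    The constants are arbitrary; the real op would use a proper
    splittable PRNG. Shown only to illustrate the key discipline.
    """
    left = (key * 6364136223846793005 + 1) & 0xFFFFFFFFFFFFFFFF
    right = (key * 6364136223846793005 + 2) & 0xFFFFFFFFFFFFFFFF
    return left, right

def split_n(key, n):
    """Split one key into n keys by chaining pairwise splits.

    Every intermediate key is used exactly once (as a linear resource):
    it is either handed to a layer or split again, never both.
    """
    keys = []
    for _ in range(n - 1):
        key, out = split(key)
        keys.append(out)
    keys.append(key)  # the last key is used directly, not split again
    return keys
```

With this, a frontend that needs one key per dropout layer can call `split_n(prng_key, num_dropouts)` once up front instead of threading pairwise splits by hand.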
And an example from an importer (this one is the Caffe frontend):
```python
def convert_dropout(self, op):
    """Convert Dropout layer."""
    inputs = op.bottom
    input_name = inputs[0]

    next_key, dropout_key = _op.random.split(self.prng_key)
    self.prng_key = next_key  # we can wrap this in some helper method

    params = dict()
    dropout_params = op.dropout_param
    params["rate"] = dropout_params.dropout_ratio

    in_expr = self.exp_tab.get_expr(input_name)
    out = _op.nn.dropout(in_expr, **params, key=dropout_key)
    return out
```
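The "helper method" hinted at in the comment could look like the sketch below. Everything here is hypothetical (the `PrngKeyMixin` name and the toy `split` are mine, not proposed API); the point is just that the importer holds one "current" key, and every consumer atomically swaps it for a fresh one:

```python
def split(key):
    # Toy stand-in for _op.random.split; the real op returns
    # two Relay expressions rather than Python ints.
    return (key * 2 + 1) & 0xFFFFFFFF, (key * 2 + 2) & 0xFFFFFFFF

class PrngKeyMixin:
    """Hypothetical mixin for importer frontends that need PRNG keys."""

    def __init__(self, seed_key):
        self.prng_key = seed_key

    def next_key(self):
        """Consume the current key: split it, keep one half as the new
        current key, and hand the other half to the caller. This keeps
        the use-each-key-once invariant in one place."""
        self.prng_key, fresh = split(self.prng_key)
        return fresh
```

The Caffe `convert_dropout` body above would then shrink to `key=self.next_key()`, and the same helper works for any other op that consumes randomness.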
Would love to hear your thoughts and whether I missed any edge cases here! I'm also happy to write up pseudocode for more complicated use cases if anyone is interested.
EDIT: I forgot to mention that, as @eric-haibin-lin pointed out, this means we cannot plug in off-the-shelf PRNG kernels. However, we have already written a splittable PRNG kernel (Threefry) in TIR, so some of the groundwork is done. This will also let us run on GPUs.
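To make the GPU point concrete: counter-based generators like Threefry make each output a pure function of (key, counter), so outputs can be computed in any order and in parallel. The sketch below is not Threefry, just a toy generator in the same spirit, using the splitmix64 finalizer constants purely for illustration:

```python
MASK64 = (1 << 64) - 1

def mix(x):
    """splitmix64-style finalizer: a deterministic, bijective 64-bit mix."""
    x = (x + 0x9E3779B97F4A7C15) & MASK64
    x = ((x ^ (x >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    x = ((x ^ (x >> 27)) * 0x94D049BB133111EB) & MASK64
    return x ^ (x >> 31)

def split(key):
    """Derive two child keys from one; mix is bijective, so the two
    children are always distinct."""
    return mix(key ^ 1), mix(key ^ 2)

def random_bits(key, counter):
    """Counter-based generation: output i depends only on (key, i),
    so any element of a random tensor can be filled independently --
    this is what makes the scheme parallel/GPU friendly."""
    return mix(key ^ mix(counter))
```

A real kernel (like the Threefry one in TIR) uses a cryptographically-derived block cipher rather than this mix function, but the key/counter structure, and hence the parallelism, is the same.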