[RFC] Handling Effect in TVM and Relay

I took a look at Jax’s approach and want to summarize and share it below:

1. A random number function requires an explicit random state (a “key”) as one of its arguments. Calling the random number function does not change the key/random state:

```python
nums = random.normal(key, shape=(1,))  # key is left unchanged
```

2. To get a different random state, users need to call the split function explicitly:

```python
key, subkey = random.split(key)
different_nums = random.normal(subkey, shape=(1,))
```

3. In a training program, the keys are split and replaced in Python, outside the JIT-compiled program. The JIT-compiled program just takes the new keys as inputs, and those keys are immutable.

Since random.normal does not change the state, most off-the-shelf random number generators, which keep and update their own internal state, cannot be plugged in. Jax designed its own PRNG module and implementation.
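
The mismatch can be seen with Python’s standard `random.Random`, used here only as a familiar example of a stateful generator (it is not something the Jax design refers to). Each draw mutates the object’s hidden state, so drawing twice from the same generator gives different numbers. The hypothetical `normal_from_key` adapter restores the key-in, numbers-out contract only by building a fresh generator from the key on every call.

```python
import random

# Stateful, off-the-shelf generator: each draw mutates hidden internal state.
rng = random.Random(0)
first = rng.gauss(0.0, 1.0)
second = rng.gauss(0.0, 1.0)
# The same object was used both times, yet the results differ, so a
# function like normal(rng) is not a pure function of its argument.
print(first != second)

# Hypothetical adapter: seed a brand-new generator from the key on every call.
# The key (an int) is never mutated, so equal keys always give equal numbers.
def normal_from_key(key: int, n: int) -> list:
    local = random.Random(key)
    return [local.gauss(0.0, 1.0) for _ in range(n)]

print(normal_from_key(7, 3) == normal_from_key(7, 3))
```

The adapter works in plain Python, but only by re-initializing the generator’s whole internal state on each call.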
