As TVM moves closer to supporting training, the demand for effects has been rising. Many training operators are effectful: for example, batch norm keeps track of global state, while dropout is nondeterministic.
However, effectful operators are harmful to TVM. They make passes brittle, as every graph-level pass needs to handle effects in order to be correct. (For example, a Common Subexpression Elimination pass must make sure it does not fuse two dropout nodes.) Effectful operators also make some optimizations downright impossible (gradient checkpointing, trading recomputation against moving data across machines, memoization). They also hurt reproducibility, because passes might change the order in which operators are executed, subtly changing the result.
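To make the CSE hazard concrete, here is a minimal NumPy sketch (this `dropout` is a plain-Python stand-in for the real operator, not TVM's):

```python
import numpy as np

def dropout(x, p=0.5):
    # Effectful: every call advances the global RNG and draws a fresh mask.
    mask = np.random.rand(*x.shape) >= p
    return np.where(mask, x / (1 - p), 0.0)

x = np.ones((2, 2))
a = dropout(x)
b = dropout(x)  # textually identical to `a`, yet a != b in general
# A CSE pass that merged the two calls would force a == b,
# silently changing the program's semantics.
```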
However, Jax has shown that it is possible to build a deep learning framework without any effects at all, and we can learn from that. I propose to alleviate this problem, striking a balance between IR properties and programmability, by restricting operators to three kinds, adopting PyTorch's solution.
Input operators. These operators take no arguments and return some output, read from some external interface. For example, reading from the keyboard or from a file is an input operator.
Output operators. These operators take some input, emit it on some external interface, and return no output. Examples include printing a tensor, logging it to TensorBoard, or saving it to a file.
Compute operators. These operators take input and return output, but must not have any effect. The runtime is free to execute these operators zero times (when the value is not needed), once, or multiple times.
This effectively means that all computation must be pure.
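A minimal sketch of the taxonomy (all names here are hypothetical, not existing TVM operators):

```python
import numpy as np

# Input operator: no arguments, reads from an external interface.
def read_input() -> np.ndarray:
    return np.loadtxt("input.txt")  # hypothetical data file

# Output operator: takes input, writes to an external interface, returns nothing.
def print_tensor(x: np.ndarray) -> None:
    print(x)

# Compute operator: pure, so the runtime may drop it, run it once,
# or re-execute it; the result never changes.
def scale(x: np.ndarray) -> np.ndarray:
    return 2.0 * x

x = np.ones(3)
assert np.array_equal(scale(x), scale(x))  # replaying a pure op is always safe
```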
How, then, can we handle effectful compute operators, such as batch norm and dropout?
Essentially, two kinds of effects remain in deep learning: state and randomness.
We can use Relay's Reference feature to handle state. Suppose there is a stateful operator, raw_f, that reads and writes a tensor, x. We can allocate a global Reference in the module holding x, and make raw_f pure by having it take x as input and output a new x. A wrapper function, f, then reads x from the Reference, feeds it into raw_f, and writes the output back to the Reference.
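Here is a minimal Python sketch of this rewrite; the `Ref` class merely models a Relay Reference, and `raw_f` is a hypothetical running-mean update of the kind batch norm performs:

```python
import numpy as np

class Ref:
    """Models a Relay Reference: a mutable cell held by the module."""
    def __init__(self, value):
        self.value = value

    def read(self):
        return self.value

    def write(self, value):
        self.value = value

# Pure version of the stateful operator: old state in, new state out.
def raw_f(batch: np.ndarray, x: np.ndarray) -> np.ndarray:
    return 0.9 * x + 0.1 * batch.mean(axis=0)

x_ref = Ref(np.zeros(4))  # global Reference holding the tensor x

# Wrapper f: the only effectful code, confined to Reference read/write.
def f(batch: np.ndarray) -> None:
    x = x_ref.read()
    x_ref.write(raw_f(batch, x))
```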
A nondeterministic operator is then simply an effectful operator that reads and writes a pseudo-random number generator (PRNG) state, so it can be handled the same way.
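Once the PRNG state is passed in and out explicitly, dropout becomes a pure function; this is the discipline Jax adopts with its splittable keys. A sketch using Jax's random API:

```python
import jax
import jax.numpy as jnp

# Pure dropout: PRNG state in, new PRNG state out.
def dropout(key, x, p=0.5):
    key, subkey = jax.random.split(key)                  # "write": derive fresh state
    mask = jax.random.bernoulli(subkey, 1 - p, x.shape)  # "read": consume it
    return key, jnp.where(mask, x / (1 - p), 0.0)

key = jax.random.PRNGKey(0)
x = jnp.ones((2, 2))
key, y = dropout(key, x)  # fully deterministic given the key
key, z = dropout(key, x)  # new state, hence a different mask
```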