[RFC] Handling Effect in TVM and Relay

As TVM becomes more and more ready for training, the demand for effects has been rising. A lot of training operators are effectful. For example, batch norm keeps track of global state, while dropout is nondeterministic.

However, having effectful operators is harmful to TVM. It makes passes brittle, as every graph-level pass needs to handle effects in order to be correct. (For example, a Common Subexpression Elimination pass must make sure it does not merge two dropout nodes.) Having some operators be effectful also makes some optimizations downright impossible (gradient checkpointing, choosing between recomputation and moving data across machines, memoization). It also hurts reproducibility, because passes might change the order in which operators are executed, thus subtly changing the result.

However, JAX has shown that it is possible to build a deep learning framework without any effects at all, and we can learn from that. I propose to alleviate this problem, and to strike a balance between IR properties and programmability, by adopting PyTorch's solution and restricting operators to three kinds.

Input operators. These operators take no arguments and return some output, read from some external interface. For example, reading from the keyboard or from a file is an input operator.

Output operators. These operators take some input, output it on some external interface, and return nothing. Examples include printing a tensor, logging it to TensorBoard, or saving it to a file.

Compute operators. These operators take input and return output, but must not have any effect. The runtime is free to execute these operators zero times (when the value is not needed), once, or multiple times.

This effectively means that all computation must be pure.

However, how can we handle effectful compute operators, such as batchnorm and dropout?

Basically, there are two kinds of effects left in deep learning: state and randomness.

We can use Relay's Reference feature to handle state. Suppose there is a stateful operator, raw_f, that reads and writes a Tensor, x. We can allocate a global Reference in the module, holding tensor x. Now the stateful operator can simply take x as input and output a new x. There will be a function, f, that reads x, feeds it into raw_f, and writes the output back to x.

A nondeterministic operator is simply an effectful operator that reads and writes a pseudo-random number generator (PRNG).
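
The raw_f / f split described above can be sketched in plain Python. This is only an illustration, not Relay code: the `Ref` class stands in for a Relay Reference, and the body of `raw_f` (adding one to each element) and the initial value of x are assumptions made up for the example.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Ref:
    """A mutable cell standing in for a Relay Reference (illustrative only)."""
    value: List[float]


# Module-level reference holding the tensor x (assumed initial value).
x_ref = Ref([0.0, 0.0])


def raw_f(x: List[float]) -> List[float]:
    """Pure core of the operator: takes the old x and returns a new x."""
    return [v + 1.0 for v in x]


def f() -> None:
    """Effectful wrapper: read x, feed it into raw_f, write the result back."""
    x = x_ref.value        # plays the role of a RefRead
    new_x = raw_f(x)       # pure computation, safe to optimize
    x_ref.value = new_x    # plays the role of a RefWrite


f()
f()
```

Only `f` touches the reference; `raw_f` never mutates its argument, so a pass can treat it like any other pure operator.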

As an example, let’s consider the dropout function.

An effectful dropout operator will be of type Tensor -> Tensor, but under the proposal, it should have type (Tensor, PRNG) -> (Tensor, PRNG).

Meanwhile, there will be a global rprng : (Ref PRNG) in the module, alongside a wdropout (wrapped_dropout) of type Tensor -> Tensor. It is implemented by reading rprng, calling dropout with the input and the PRNG value read from rprng, writing the output PRNG back to rprng, and returning the output tensor. This design lets the user write in an effectful style, while we can transform the code to get an effect-free dataflow graph with explicit PRNG chaining. Here is how.

Suppose the user input is a graph with wdropout. We first convert the program to A-normal form (ANF) to get a well-defined program. We can then use PartialEval/inlining to remove all calls to wdropout, and all but the first read and the last write will be eliminated. We can then convert it into a program that reads once, goes into graph mode, writes once, then returns.

Thanks @MarisaKirisame. Can we write down examples of the programs you talk about in text format (before and after the PartialEval)?

Thanks Marisa for bringing this up! Overall I like this proposal, especially how it eliminates all refs except the first/last ones.
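
A minimal Python sketch of the dropout / wdropout pair follows. It is not TVM code: the PRNG is assumed to be a single integer advanced by a toy linear congruential step, tensors are plain lists, and the `Ref` class, `next_float` helper, and the seed 42 are all hypothetical names and values chosen for the illustration.

```python
from typing import List, Tuple

PRNG = int  # assumption: the PRNG state is a single integer


def next_float(prng: PRNG) -> Tuple[float, PRNG]:
    """One toy linear congruential step: returns (sample in [0, 1), new state)."""
    new = (1103515245 * prng + 12345) % (2 ** 31)
    return new / (2 ** 31), new


def dropout(x: List[float], prng: PRNG, rate: float = 0.5) -> Tuple[List[float], PRNG]:
    """Pure dropout of type (Tensor, PRNG) -> (Tensor, PRNG)."""
    out = []
    for v in x:
        u, prng = next_float(prng)
        # Drop with probability `rate`, otherwise rescale the kept value.
        out.append(0.0 if u < rate else v / (1.0 - rate))
    return out, prng


class Ref:
    """A mutable cell standing in for a Relay Reference (illustrative only)."""

    def __init__(self, value):
        self.value = value


rprng = Ref(42)  # the global rprng : (Ref PRNG), with an arbitrary seed


def wdropout(x: List[float]) -> List[float]:
    """Effectful wrapper of type Tensor -> Tensor."""
    prng = rprng.value              # read rprng
    y, new_prng = dropout(x, prng)  # call the pure operator
    rprng.value = new_prng          # write the new PRNG state back
    return y
```

Calling `dropout` twice with the same state gives the same result, while two consecutive `wdropout` calls see different states because each call advances `rprng`.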

There is one potential concern: thread safety on global resources. Say people are doing multi-threaded inference on the same module (you may argue that this is incorrect, but there are actually many industrial use cases). If there is an operator like dropout that takes a global state, it races; batch norm has a similar problem too.

We shall add AtomicUpdate to Reference. The reads and writes will then be eliminated down to a single AtomicUpdate.

I don't think we need to complicate the issue here, in particular with respect to thread safety and atomic update (we don't have to introduce atomic update).

Note that the final goal (which explicitly passes the state around) won't have this problem, while the program semantics that users write (which relies on a PRNG state that we can define as global to the module or thread local) does have the problem.

The main question is how easy it is to convert the program that uses the PRNG state to the explicit passing style. After that conversion, the execution engine can then decide how to implement the PRNG.

I agree with atomic update :slight_smile:

But the essential problem here is that we are adding a virtual dependency between operators that have the same effect. For example, in a Relay function, there might be two dropouts. If we force them to share the same global PRNG, those two dropouts will be serialized.

We could insert ref read/write around the operators that have effects; this ensures correctness. Then, as Marisa mentioned, we could rely on partial evaluation to optimize most of them out.

What is happening here is that we represented a concept (nondeterminism) with its implementation (PRNG). The principled solution is to add nondeterminism into Relay; then we can do optimization at the Relay level, knowing that two nondeterministic functions can be swapped.

Per offline discussion with @tqchen, let's limit the scope to the usage of PRNG states.

Solution A1. Allocating a single global reference. As suggested by @MarisaKirisame, we can maintain a global state inside a module, which holds the PRNG key. Every time we encounter a stateful operator f, we insert a RefRead before f, whose result is passed to f as an additional argument, and then put a RefWrite after f.


fn my_func(x) {
    let y = Dropout(x);
    y
}

will be transformed to

fn my_func(x) {
    let prng_key = RefRead(global_prng_state);
    let (y, new_prng_key) = Dropout(x, prng_key);
    RefWrite(global_prng_state, new_prng_key);
    y
}

This transformation should be applied at the earliest stage, to make sure all our existing passes keep functioning correctly.

Solution A2. Passing states everywhere. We can also add an extra argument called prng_state to all functions (including closures), and add an extra return item as well. Each newly generated PRNG state will shadow the old one instead of using mutation. Note that in this case, we should handle if expressions properly: every branch has to return a PRNG state, even a branch that does not use it.


fn my_func(x) {
    let y = Dropout(x);
    y
}

will be transformed to

fn my_func(x, prng_state) {
    let (y, prng_state) = Dropout(x, prng_state);
    (y, prng_state)
}

The extra argument could be eliminated if and only if all callees in the call-graph do not use it.
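
The point about if expressions in A2 can be illustrated with a small Python sketch. It is not Relay: the `training` flag, the list-based tensors, and the toy linear congruential PRNG inside `dropout` are assumptions for the example. What it shows is that both branches must produce a PRNG state, so the branch that does not call Dropout simply hands the incoming state on.

```python
from typing import List, Tuple


def dropout(x: List[float], prng_state: int, rate: float = 0.5) -> Tuple[List[float], int]:
    """Pure dropout: returns the output and the advanced PRNG state."""
    out = []
    for v in x:
        # Toy linear congruential step (illustrative, not a quality PRNG).
        prng_state = (1103515245 * prng_state + 12345) % 2**31
        keep = prng_state / 2**31 >= rate
        out.append(v / (1.0 - rate) if keep else 0.0)
    return out, prng_state


def my_func(x: List[float], training: bool, prng_state: int) -> Tuple[List[float], int]:
    """A2 style: the state is an extra argument and an extra return item."""
    if training:
        # This branch consumes randomness and shadows prng_state with the new one.
        y, prng_state = dropout(x, prng_state)
    else:
        # This branch uses no randomness, but still passes the state through unchanged.
        y = x
    return y, prng_state
```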

From Solution A1 to A2. There is indeed a connection between A1 and A2. In many cases, using partial evaluation will reduce Solution A1 to A2.
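
As a hand-worked illustration of that connection, here is a Python sketch with two chained dropouts. `my_func_a1` is the A1 form with a read and a write around each call; `my_func_core` is the A2 form; `my_func_reduced` is what one would hope to get after partial evaluation, with one read at the start and one write at the end. The `Ref` class, the function names, and the toy PRNG are hypothetical, and the reduction is done by hand here rather than by a real pass.

```python
from typing import List, Tuple


class Ref:
    """A mutable cell standing in for a Relay Reference (illustrative only)."""

    def __init__(self, value: int):
        self.value = value


def dropout(x: List[float], prng_state: int, rate: float = 0.5) -> Tuple[List[float], int]:
    """Pure dropout with a toy linear congruential PRNG (illustrative)."""
    out = []
    for v in x:
        prng_state = (1103515245 * prng_state + 12345) % 2**31
        keep = prng_state / 2**31 >= rate
        out.append(v / (1.0 - rate) if keep else 0.0)
    return out, prng_state


def my_func_a1(x: List[float], ref: Ref) -> List[float]:
    """A1: a read before and a write after every stateful operator."""
    k = ref.value
    y, k = dropout(x, k)
    ref.value = k
    k = ref.value
    z, k = dropout(y, k)
    ref.value = k
    return z


def my_func_core(x: List[float], prng_state: int) -> Tuple[List[float], int]:
    """A2: pure body with explicit PRNG chaining."""
    y, prng_state = dropout(x, prng_state)
    z, prng_state = dropout(y, prng_state)
    return z, prng_state


def my_func_reduced(x: List[float], ref: Ref) -> List[float]:
    """Read once, run the pure body, write once."""
    z, new_state = my_func_core(x, ref.value)
    ref.value = new_state
    return z
```

Starting from the same seed, `my_func_a1` and `my_func_reduced` return the same tensor and leave the same state in the reference.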

A2 is essentially a monad. While it might be a good idea here, I do not think it is a good solution to the effect problem.

A monad is a programming technique used in Haskell and other purely functional programming languages, which are pure and must respect referential transparency (calling the same function twice with the same arguments gives the same result). A monad encapsulates effects by, in a sense, creating a sub-language that can do whatever the metalanguage (Haskell) can do, plus an API for certain impure effects.

For example, here is the monad that corresponds to arbitrary state:

data State a b  -- has a global variable of type a, and eventually returns a b
  = ReturnState b          -- a plain b value in Haskell
  | Read (a -> State a b)  -- read the global variable, then continue executing with that value
  | Write a (State a b)    -- write to the global variable, then continue executing
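
To make the "sub-language" reading concrete, the data type above can be mirrored in Python together with a small interpreter. This is a sketch under assumptions: the class names follow the constructors above, while `run_state` and the `tick` program are hypothetical additions for the illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple, Union


@dataclass
class ReturnState:
    value: Any                       # finished: a plain result value


@dataclass
class Read:
    cont: Callable[[Any], "State"]   # continue with the current global value


@dataclass
class Write:
    new_state: Any                   # value to store in the global variable
    rest: "State"                    # program to continue with afterwards


State = Union[ReturnState, Read, Write]


def run_state(prog: State, s: Any) -> Tuple[Any, Any]:
    """Interpret prog with global variable s; return (result, final state)."""
    while True:
        if isinstance(prog, ReturnState):
            return prog.value, s
        if isinstance(prog, Read):
            prog = prog.cont(s)
        elif isinstance(prog, Write):
            s = prog.new_state
            prog = prog.rest
        else:
            raise TypeError(f"not a State program: {prog!r}")


# Read the counter, write counter + 1, and return the old counter value.
tick = Read(lambda n: Write(n + 1, ReturnState(n)))
result, final_state = run_state(tick, 10)
```

Building `tick` is pure; the effect only happens inside `run_state`. Note that the continuation in `Read` is an opaque function, which is why a pass cannot easily inspect what comes after a read.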

However, while the outer language is pure, the inner language (State, in this example) still has effects, and many passes need to do effect-specific optimization (e.g. if two writes do not alias, you can swap them). Operating on the pure fragment merely lets us build up this monad purely, but generally cannot peek into its structure. Using monads also brings tons of programming problems (what happens if you have two monads and you want to combine them? This is a very hard problem in Haskell, and people are still designing new solutions to this day).

The more principled solution, IMO, is to add effects directly into Relay, and add optimizations at that level. If the above A2 solution is desired, we can add a "local variable" effect, which should be easier to optimize.

@jroesch, do you have some time to hop in?

Hey Marisa,

You are right, it is essentially a (very simplified) monad. This is why I limited the discussion scope to PRNG only, and why I said that "in many cases" A1 can be reduced to A2, but not in all cases. I also agree with you that bringing in full monads brings us tons of problems, and that it is impractical for now.

To make it practical, let's simplify the problem by limiting the scope. Let's step back and look at what the problem looks like: contributors of frontend importers want to use 1) PRNGs; 2) batch normalization (cuDNN's implementation contains inplace mutation, so unfortunately many libraries just follow this design).

For PRNGs, the goal is just to ensure reproducibility (or referential transparency, to be exact).

For batch normalization, from my point of view, it is more correct to just make it pure, like: BatchNorm(x, running_mean, running_var) -> (y, new_running_mean, new_running_var). But we need to make sure frontend importers work this way.

How does TVM deal with inplace mutation in general, btw? Does TVM have a plan to support inplace? There should be some optimization that turns an update of a Ref of Tensor into an update of the internal value (if possible). I think MXNet does this and we should too.

MXNet supports inplace mutation. To deal with races, it has an almost general-purpose dependency engine, which builds a dependency graph inside: https://mxnet.apache.org/api/architecture/note_engine

Immediate idea: if this is costly, we can use static analysis to pick out the race-free cases. But I assume no such cost comes close to that of calling a GPU kernel... sigh

Also, for the above code, we can write a pass that is guaranteed to turn A1 into A2, by essentially injecting a 'monad' into the program. This is still transparent to the end user though.
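
A minimal Python sketch of such a pure signature, for a single feature with the batch given as a list of floats. The `momentum` and `eps` parameters, their default values, and the exponential-moving-average update are assumptions borrowed from common practice, not something fixed by this thread.

```python
from typing import List, Tuple


def batch_norm(
    x: List[float],
    running_mean: float,
    running_var: float,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tuple[List[float], float, float]:
    """Pure batch norm: returns (y, new_running_mean, new_running_var)."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased batch variance
    y = [(v - mean) / (var + eps) ** 0.5 for v in x]
    # The running statistics are returned as new values, never mutated in place.
    new_running_mean = (1.0 - momentum) * running_mean + momentum * mean
    new_running_var = (1.0 - momentum) * running_var + momentum * var
    return y, new_running_mean, new_running_var


y, new_mean, new_var = batch_norm([1.0, 3.0], running_mean=0.0, running_var=1.0)
```

A frontend importer that targets an inplace implementation would then be responsible for writing `new_mean` and `new_var` back to wherever the running statistics live.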

One thing worth noting is that we are not trying to optimize general-purpose programs, but deep learning programs that happen to have (a few) state updates due to PRNG, parameter updates and IO.

We should bear in mind that most developers do not know how to deal with state, and most optimizations do not need to deal with it. So in the ideal case we want to isolate the pure parts into basic blocks, as in Basic Block Normal Form, and not touch state updates in most cases.

While it is possible to do alias analysis and move mutation around, we have to admit that the cost of writing compiler passes for this style of program (as in A1) is higher. The explicit passing style (monad, A2), on the other hand, brings a larger region of "pure programs", which means the basic block optimizers could likely bring more benefits (by being able to reorder other ops with random ones, for example). The explicit passing style is also more friendly to being transformed into an execution graph that runs on multiple machines (devices), in which case we would need to pass the state around.

On the other hand, from the frontend writer's perspective, it could be tricky to write all state updates in the monad style, and the program in A1 is much easier to write because the states do not have to be passed around.

So to summarize some key points:

  • It is relatively easy to write optimizations for pure basic blocks, and we want to enable that (by creating bigger basic blocks when possible).
  • The state mutation program breaks basic blocks apart (because we want to make sure basic blocks remain pure), so the program in A1 is less friendly to optimization (of course alias analysis would help, but that adds complexity to pass development). Explicit state passing style in a bigger block is more desirable for optimization (of course we could still use state read/write as the source/sink of the state passing).
  • On the other hand, from the frontend perspective, most programs are written in the A1 style.

Considering the above points, I think it is important to:

  • Design an interface that explicitly passes state around (while allowing refs to record state updates)
  • Design rewriting passes to rewrite A1 style programs into A2
    • The rough idea is to collect all the affected functions and their states, and add the state as function arguments and return values.
    • Note that the rewriting can be done partially, which only removes some of the state mutations, but still keeps state read/write at an upper-level source/sink.
  • Encourage users to write basic block optimizations while dealing with functions that explicitly pass the state

For Relay/TVM, are we mainly concerned with compute operators (random number generators), or do we want to support IO input / IO output as well?

I think there is not much difference in this case
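
The rewriting idea in the list above can be sketched on a toy straight-line IR in Python. Everything here is hypothetical and far simpler than a real Relay pass: instructions are tuples, every variable is assumed to be assigned once (as after ANF), there is no control flow, and nothing else observes the reference in the middle of the block. The pass keeps the first read of each reference as the source, forwards written values to later reads, and defers a single write to the end as the sink.

```python
from typing import Dict, List, Tuple

Instr = Tuple  # ("read", dst, ref) | ("write", ref, src) | ("call", dsts, op, args)


def eliminate_intermediate_ref_ops(program: List[Instr]) -> List[Instr]:
    known: Dict[str, str] = {}       # ref -> variable currently holding its value
    last_write: Dict[str, str] = {}  # ref -> variable to write back at the end
    alias: Dict[str, str] = {}       # eliminated read result -> replacement variable
    out: List[Instr] = []
    for ins in program:
        if ins[0] == "read":
            _, dst, ref = ins
            if ref in known:
                alias[dst] = known[ref]  # reuse the known value, drop the read
            else:
                known[ref] = dst
                out.append(ins)          # the first read is kept as the source
        elif ins[0] == "write":
            _, ref, src = ins
            src = alias.get(src, src)
            known[ref] = src
            last_write[ref] = src        # defer: only the last write survives
        else:
            _, dsts, op, args = ins
            out.append(("call", dsts, op, tuple(alias.get(a, a) for a in args)))
    for ref, src in last_write.items():
        out.append(("write", ref, src))  # single write per reference, as the sink
    return out


a1_program = [
    ("read", "k0", "prng"),
    ("call", ("y", "k1"), "dropout", ("x", "k0")),
    ("write", "prng", "k1"),
    ("read", "k2", "prng"),
    ("call", ("z", "k3"), "dropout", ("y", "k2")),
    ("write", "prng", "k3"),
]
rewritten = eliminate_intermediate_ref_ops(a1_program)
```

In `rewritten`, the two dropout calls are chained directly through `k1`, which is the explicit state passing of A2 between one remaining read and one remaining write.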