Two missing pieces for training

Hi all, I've recently started helping with the training support in TVM. What I hope to do is train a trivial MNIST model by converting the official PyTorch example to TVM.

After implementing the nll_loss op (which is under review) and its gradient, I successfully got the correct gradient values by commenting out the dropout part of the model. However, to complete the training of this tiny model, I believe there are still some missing pieces, and I hope we can have some discussion in this thread so that we have a clear roadmap when contributing to the training support~

Apply Gradient and Optimizers

To update the weights in each iteration, it seems that we currently need to convert the returned gradients back to PyTorch tensors. Copying the parameters back and forth is intolerable for high-performance training, so we need to save and update the state of the weights inside TVM. I wonder if the community has any thoughts on this.

In my opinion, we'd better not put the gradient update or optimizer kernels in the same Relay graph as the rest of the training. Though the idea of a unified IR graph may be beautiful, it would make distributed training hard to implement, since we would then need to add the communication ops to the Relay graph. If we separate the pure forward and backward graph from the weight-updating part, we could, for example, add TVM support to Horovod and easily get data-parallel functionality.
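As a rough illustration of what I mean (plain NumPy standing in for compiled TVM modules; all names here are hypothetical, not real TVM APIs), the forward+backward graph would only produce gradients, and the averaging and update steps would live entirely outside it:

```python
import numpy as np

# Hypothetical stand-in for a compiled forward+backward module:
# given weights and a batch, it returns only the loss and gradients.
def fwd_bwd(w, x, y):
    pred = x @ w                               # trivial linear model
    loss = float(np.mean((pred - y) ** 2))
    grad = 2 * x.T @ (x @ w - y) / len(x)      # MSE gradient
    return loss, grad

# Because the update is outside the graph, a data-parallel wrapper
# (e.g. a Horovod-style allreduce) can average gradients from
# several workers before the weights are touched.
def allreduce_mean(grads):
    return sum(grads) / len(grads)

def sgd_step(w, grad, lr=0.1):
    return w - lr * grad

rng = np.random.default_rng(0)
w = np.zeros((3, 1))
x = rng.normal(size=(8, 3))
y = x @ np.array([[1.0], [2.0], [3.0]])        # ground-truth weights

for _ in range(500):
    # simulate two workers computing grads on their data shards
    _, g0 = fwd_bwd(w, x[:4], y[:4])
    _, g1 = fwd_bwd(w, x[4:], y[4:])
    w = sgd_step(w, allreduce_mean([g0, g1]))
```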

Separate Training and Inference Mode of Ops

Since we already have PRNG support in Relay, we could write a train-mode dropout with an extra key argument. To add the train-mode dropout to the graph, especially from the frontend parsers, we need some way to identify whether TVM is in training mode or inference mode and do the parsing based on the mode. Also, as training mode often comes with side effects, we need to settle how to maintain the state for these ops (e.g. the PRNG key for dropout, the moving mean and variance for batch norm). One possible solution is saving the state outside the graph and passing it in as inputs. This would be coherent with separating the gradient updates from the other calculations.
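For example, a train-mode dropout could thread an explicit key through the graph, so all randomness state lives outside the graph as an input/output pair (a NumPy sketch of the idea, not the actual Relay PRNG API):

```python
import numpy as np

# Hypothetical train-mode dropout: takes an explicit PRNG key and
# returns the next key along with the output, keeping the op pure.
def dropout_train(x, key, rate=0.5):
    rng = np.random.default_rng(key)
    mask = rng.random(x.shape) >= rate
    new_key = int(rng.integers(0, 2**31 - 1))  # derive the next key
    # scale kept activations so the expected value is unchanged
    return np.where(mask, x / (1.0 - rate), 0.0), new_key

x = np.ones((4, 4))
key = 42
out1, key = dropout_train(x, key)  # deterministic given the key
out2, key = dropout_train(x, key)  # next call uses the derived key
```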

Thank you for your time on this post~ I genuinely hope we can reach some conclusions on these two parts.

Gently ping people who may be interested. @tqchen @tkonolige @altanh @MarisaKirisame @junrushao1994 @eric-haibin-lin


Hi Zilin!

I agree with your points- let me contribute some information about what I've been working on at OctoML for TVM training and expand on your questions.

  1. Regarding optimizers, there are (at least) two paths as you mentioned. The first one is to inline the optimizer into the forward+backward graph, and the second is to compile a separate module which performs the weight update step and feed the gradients. Note that optimizers often have internal state (such as velocity for momentum SGD, other gradient statistics, etc.), so we should view an optimizer as an update function of (weights, grads, state) -> (weights', state'). Naturally, we can define a Relay expression which represents the updated weights and state- this is nice as inlining the optimizer is as simple as plugging in the update expression to the end of the graph and passing in the gradient exprs (note we also need to modify the forward+backward graph to accept optimizer state as input). The benefit of inlining (ignoring distributed problems for now) is that it could allow the optimizer step to be fused into the gradient computation, which may bring some perf improvements. Regardless, by defining the optimizer functionally this way, we could also compile the update step externally and avoid inlining easily.

  2. Regarding training vs inference mode, note that TVM already has a distinction in the compilation workflow via the SimplifyInference pass- for compiling training graphs that include dropout and batch norm, disabling SimplifyInference will keep these operators. Regarding side effects, I think it’d be great if we could keep things functional and have the graph be (..., state) -> (..., state') as it relieves the burden on TVM of reasoning about mutation. That being said, it may be possible in the future to use some kind of references for handling state, but currently we don’t have much support for this.

  3. I also wanted to mention another point here which is sparse tensor support. Notably, some models like DLRM create large embedding tables for which generating a dense gradient is not advisable (assuming the embedding weight is never used for other computations, which would create dense grads e.g. in the case of BERT decoding). For this we would need to better support sparse tensors natively in order for the AD algorithm and optimizer to handle the gradient computation and weight update.
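Concretely, the sparse representation can be as small as the rows actually touched in a batch (a NumPy sketch with made-up sizes, not the eventual TVM sparse-tensor API):

```python
import numpy as np

# A batch only looks up a few rows of a large embedding table, so the
# gradient can be represented as (row_indices, row_values) instead of
# a dense table-sized tensor.
table = np.zeros((1000, 4))            # 1000-row embedding table
indices = np.array([3, 7, 3])          # rows looked up in this batch
row_grads = np.ones((3, 4))            # upstream grad per lookup

# Sparse SGD update: scatter-add the per-lookup grads into just the
# touched rows (np.add.at accumulates the repeated index 3 correctly).
lr = 0.1
np.add.at(table, indices, -lr * row_grads)
```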

Finally, I wanted to mention that internally at OctoML we’ve hit some nice milestones of being able to train BERT and DLRM end-to-end in TVM using this approach (however, we had to do a bit of hacking to get DLRM sparse gradient update to work). We’re working on cleaning up the code and preparing for an open source release, so stay tuned! Let me know if you have any other questions :slight_smile:


Thank you for the information! I'm really looking forward to the BERT and DLRM models :smiley:

Note that optimizers often have internal state (such as velocity for momentum SGD, other gradient statistics, etc.), so we should view an optimizer as an update function of (weights, grads, state) -> (weights', state') . … Regardless, by defining the optimizer functionally this way, we could also compile the update step externally and avoid inlining easily.

I love the idea of representing the optimizer as a Relay function and letting the user choose whether to inline it! I'm totally fine as long as users can separate the main calculation from the gradient update~
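To make sure I understand the (weights, grads, state) -> (weights', state') view, here is a minimal NumPy sketch of momentum SGD written that way (names are my own, not an existing TVM API):

```python
import numpy as np

# Pure-function optimizer: takes weights, grads, and its own state
# (the momentum velocity), returns updated weights and state without
# mutating anything, so it can be inlined or compiled separately.
def momentum_sgd(weights, grads, state, lr=0.01, mu=0.9):
    new_velocity = mu * state["velocity"] + grads  # accumulate momentum
    new_weights = weights - lr * new_velocity      # apply the step
    return new_weights, {"velocity": new_velocity}

w = np.array([1.0, -2.0])
state = {"velocity": np.zeros_like(w)}
g = np.array([0.5, 0.5])
w, state = momentum_sgd(w, g, state)
```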

Regarding training vs inference mode, note that TVM already has a distinction in the compilation workflow via the SimplifyInference pass- for compiling training graphs that include dropout and batch norm, disabling SimplifyInference will keep these operators.

If I understand this correctly, this means we will not have separate training and inference versions of ops like dropout and batch norm. Then all the dropouts in the Relay graph will be interpreted as the training ones, and an additional PRNG key argument will be added to the main function. In that case, if we want to do evaluation during training (after each epoch, for example), do we need to transform the module with SimplifyInference every time? And what if there is an op that behaves differently in training and inference but cannot be eliminated by SimplifyInference?
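To make the question concrete, batch norm is the kind of op I mean: its train-mode version uses batch statistics and updates the running state, while its infer-mode version reads the running state (a NumPy sketch with hypothetical names, following the functional state-threading idea above):

```python
import numpy as np

# Train mode: normalize with batch statistics and return updated
# running statistics, keeping the op pure: (..., state) -> (..., state').
def batch_norm_train(x, gamma, beta, state, momentum=0.1, eps=1e-5):
    mean, var = x.mean(axis=0), x.var(axis=0)
    out = gamma * (x - mean) / np.sqrt(var + eps) + beta
    new_state = {
        "running_mean": (1 - momentum) * state["running_mean"] + momentum * mean,
        "running_var": (1 - momentum) * state["running_var"] + momentum * var,
    }
    return out, new_state

# Infer mode: normalize with the stored running statistics; no state
# is produced, so this version cannot be derived by simply dropping
# the op from the graph.
def batch_norm_infer(x, gamma, beta, state, eps=1e-5):
    return (gamma * (x - state["running_mean"])
            / np.sqrt(state["running_var"] + eps) + beta)

x = np.array([[0.0, 10.0], [2.0, 14.0]])
state = {"running_mean": np.zeros(2), "running_var": np.ones(2)}
out, state = batch_norm_train(x, 1.0, 0.0, state)
```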

BTW, is there any plan on lowering part of the graph from frameworks to TVM, like the nice blog Bridging PyTorch and TVM did with BERT? In practice, we may encounter people using all kinds of rare ops in their training, even custom ops. It would be great if we could optimize just part of the original training graph (for example, the common backbone or the loss) and leave the other ops to the original frameworks. The main barrier would be how to separate the forward and backward graphs, as mentioned in the blog. I wonder if you have any thoughts on this? Thank you~

In that case, if we want to do evaluation during training (after each epoch, for example), do we need to transform the module with SimplifyInference every time? And what if there is an op that behaves differently in training and inference but cannot be eliminated by SimplifyInference?

Good point- currently the approach I'm using is to compile separate training and inference modules once for the same graph (with the inference module compiled using SimplifyInference). The benefit of this approach is that we may be able to apply more optimizations to the inference module (e.g. fusion, because there are no backward dependencies), which could speed up evaluation. The downsides are that we need to 1) compile an external loss function for evaluation (unless you want to inline the loss, which makes prediction impossible) and 2) accept longer compile times and a larger memory footprint.

And what if there is an op that behaves differently in training and inference but cannot be eliminated by SimplifyInference?

Ultimately we have 2 options:

  1. Disallow this kind of behavior and explicitly create training/inference versions (e.g. nn.dropout.infer, nn.dropout.train)
  2. Allow some global TVM state which changes the runtime behavior of certain ops.

Option 2 has some broader implications about the TVM runtime and compilation which I’m not sure we want to get into (for example, how do we support this kind of state in AOT or microTVM?).

BTW, is there any plan on lowering part of the graph from frameworks to TVM, like the nice blog Bridging PyTorch and TVM did with BERT? In practice, we may encounter people using all kinds of rare ops in their training, even custom ops. It would be great if we could optimize just part of the original training graph (for example, the common backbone or the loss) and leave the other ops to the original frameworks. The main barrier would be how to separate the forward and backward graphs, as mentioned in the blog. I wonder if you have any thoughts on this?

Yeah, I followed that blog post quite closely during my development as well, and I think it's pretty impressive. For context, the approach we decided to go with was to design a standalone TVM training framework (with modules like PyTorch, etc.), as this gave more control over the input graph and training API. However, I strongly agree that there is a large ecosystem around PyTorch, and it would be hugely beneficial to integrate TVM in a way that avoids users needing to port their models. We haven't explored this path much yet, so I don't have much to comment on, but it's definitely an important direction!


Thank you for the detailed explanation! I'm really looking forward to the release of the BERT training work! I wonder if there is anything I can do before the release? Just afraid of reinventing the wheel :stuck_out_tongue:


Hey, sorry for the delayed reply- in terms of the timeline, we are trying to upstream and open source the work around Q3 this year, and in the meantime I'm going to start upstreaming TVM improvements. Generally, we could still use a good amount of training operator coverage (gradients, loss ops, dropout, etc.), which should be complementary to other changes- please let me know if you have any specific ideas or questions!
