Two missing pieces for training

Hi Zilin!

I agree with your points. Let me share some of what I've been working on at OctoML for TVM training and expand on your questions.

  1. Regarding optimizers, there are (at least) two paths, as you mentioned. The first is to inline the optimizer into the forward+backward graph, and the second is to compile a separate module that performs the weight update step and feed it the gradients. Note that optimizers often have internal state (such as velocity for momentum SGD, other gradient statistics, etc.), so we should view an optimizer as an update function of (weights, grads, state) -> (weights', state'). Naturally, we can define a Relay expression that represents the updated weights and state. This is nice because inlining the optimizer is then as simple as plugging the update expression into the end of the graph and passing in the gradient exprs (note we also need to modify the forward+backward graph to accept optimizer state as input). The benefit of inlining (ignoring distributed problems for now) is that it could allow the optimizer step to be fused into the gradient computation, which may bring some perf improvements. Regardless, by defining the optimizer functionally this way, we could just as easily compile the update step externally and skip inlining.

  2. Regarding training vs inference mode, note that TVM already has a distinction in the compilation workflow via the SimplifyInference pass: when compiling training graphs that include dropout and batch norm, disabling SimplifyInference will keep these operators. Regarding side effects, I think it'd be great if we could keep things functional and have the graph be (..., state) -> (..., state'), as that relieves TVM of the burden of reasoning about mutation. That being said, it may be possible in the future to use some kind of references for handling state, but currently we don't have much support for this.

  3. I also wanted to raise another point here: sparse tensor support. Notably, some models like DLRM create large embedding tables for which generating a dense gradient is not advisable (assuming the embedding weight is never used in other computations that would create dense grads, as in BERT decoding, for example). For this we would need better native support for sparse tensors so that the AD algorithm and optimizer can handle the gradient computation and weight update.
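
To make the functional view in points 1 and 3 a bit more concrete, here is a minimal sketch in plain Python (not Relay, and not the code we use internally). The function names `momentum_sgd_update` and `sparse_sgd_update`, the list-based tensors, and the dict of per-row gradients are all assumptions made purely for illustration. The first function has the (weights, grads, state) -> (weights', state') shape described above; the second only touches embedding rows that actually received a gradient instead of materializing a dense one.

```python
from typing import Dict, List, Tuple

Vector = List[float]


def momentum_sgd_update(
    weights: Vector,
    grads: Vector,
    velocity: Vector,
    lr: float = 0.1,
    momentum: float = 0.9,
) -> Tuple[Vector, Vector]:
    """Pure update step: (weights, grads, state) -> (weights', state').

    The optimizer state here is just the velocity vector. Nothing is
    mutated; new lists are returned for both weights and state.
    """
    new_velocity = [momentum * v + g for v, g in zip(velocity, grads)]
    new_weights = [w - lr * v for w, v in zip(weights, new_velocity)]
    return new_weights, new_velocity


def sparse_sgd_update(
    table: List[Vector],
    row_grads: Dict[int, Vector],
    lr: float = 0.1,
) -> List[Vector]:
    """Plain SGD on an embedding table with a sparse gradient.

    `row_grads` maps a row index to the gradient for that row (a
    hypothetical sparse format). Rows without a gradient are reused
    as-is rather than being updated with zeros.
    """
    new_table = list(table)  # shallow copy: untouched rows are shared
    for row, grad in row_grads.items():
        new_table[row] = [w - lr * g for w, g in zip(table[row], grad)]
    return new_table


# Usage: state is threaded through explicitly from one step to the next.
weights, velocity = [1.0, 2.0], [0.0, 0.0]
weights, velocity = momentum_sgd_update(weights, [0.5, -0.5], velocity)

table = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
updated_table = sparse_sgd_update(table, {1: [1.0, 1.0]})
```

Because both functions are pure, the same definition could in principle be appended to the forward+backward graph (inlining) or compiled on its own and called with the gradients, which is the flexibility point 1 is after.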

Finally, I wanted to mention that internally at OctoML we've hit some nice milestones: we're able to train BERT and DLRM end-to-end in TVM using this approach (though we had to do a bit of hacking to get the DLRM sparse gradient update to work). We're working on cleaning up the code and preparing for an open source release, so stay tuned! Let me know if you have any other questions :)
