Hi all, I recently started helping with training support in TVM. My goal is to train a trivial MNIST model by converting the official PyTorch example to TVM.
After implementing the nll_loss op (currently under review) and its gradient, I was able to get the correct gradient values once I commented out the dropout part of the model. However, to complete the training of this tiny model, I believe there are still missing pieces, and I hope we can discuss them in this thread so that we have a clear roadmap when contributing to training support~
To update the weights in each iteration, it currently seems that we need to convert the returned gradients back to PyTorch tensors. Copying the parameters back and forth on every step is not acceptable for high-performance training, so we need to keep and update the weight state inside TVM. I wonder if the community has any thoughts on this.
In my opinion, we'd better not put the gradient update or optimizer kernels in the same Relay graph as the rest of the training. Though the idea of a unified IR graph may be beautiful, it would make distributed training hard to implement, since we would then need to add communication ops to the Relay graph. If we separate the pure forward and backward graph from the weight-update part, we could, for example, add TVM support to Horovod and get data parallelism fairly easily.
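To make the separation concrete, here is a rough sketch in plain NumPy, not TVM code. `forward_backward` stands in for the compiled forward and backward graph, `sgd_step` is the optimizer that lives outside it, and `allreduce_mean` is a hypothetical hook marking where a Horovod-style gradient average would go. All three names, the linear model and the MSE loss are my own assumptions for illustration only.

```python
import numpy as np

def forward_backward(params, x, y):
    """Pure function standing in for the compiled forward+backward graph.

    It reads the weights but never mutates them; it only returns loss and grads.
    """
    pred = x @ params["w"] + params["b"]
    err = pred - y
    loss = float(np.mean(err ** 2))
    # Gradients of the mean squared error with respect to w and b.
    grads = {
        "w": 2.0 * x.T @ err / err.size,
        "b": 2.0 * err.sum(axis=0) / err.size,
    }
    return loss, grads

def allreduce_mean(grads, peers=()):
    """Hypothetical communication hook: average gradients across workers.

    With no peers (single process) this returns the gradients unchanged.
    """
    all_grads = [grads, *peers]
    return {k: sum(g[k] for g in all_grads) / len(all_grads) for k in grads}

def sgd_step(params, grads, lr=0.1):
    """Optimizer kept outside the graph: updates the weight state in place."""
    for name in params:
        params[name] -= lr * grads[name]

# Toy data: a linear target, so plain SGD should drive the loss down.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 3))
y = x @ np.array([[1.0], [-2.0], [0.5]]) + 0.3

params = {"w": np.zeros((3, 1)), "b": np.zeros(1)}
losses = []
for _ in range(200):
    loss, grads = forward_backward(params, x, y)  # graph: no side effects
    grads = allreduce_mean(grads)                 # communication: outside the graph
    sgd_step(params, grads)                       # weight update: outside the graph
    losses.append(loss)
```

In this layout the graph itself needs no communication ops. Going data parallel would only mean swapping the body of `allreduce_mean`, while the weights stay in one place and are never copied to another framework.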
Since we already have PRNG support in Relay, we could write a training-mode dropout that takes an extra key argument. To add this dropout to the graph, especially from the frontend parser, we need some way to identify whether TVM is in training mode or inference mode and parse based on that mode. Also, as training mode often comes with side effects, we need to settle how to maintain the state for them (e.g. the PRNG key of dropout, the moving mean and variance for batchnorm). One possible solution is to save the state outside the graph and pass it in as inputs. This would be consistent with separating the gradient updates from the other calculations.
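As a rough illustration of "state outside the graph, passed in as inputs", here is a NumPy sketch, again not Relay code. The function names, the integer used as a key, and the `momentum` convention for the moving statistics are all my own assumptions. The point is only that each op takes its state as an argument and returns the new state instead of mutating anything.

```python
import numpy as np

def train_dropout(x, key, rate=0.5):
    """Training-mode dropout as a pure function of (x, key).

    The same key always gives the same mask. Returns (output, new_key).
    """
    rng = np.random.default_rng(key)
    keep = rng.random(x.shape) >= rate
    out = np.where(keep, x / (1.0 - rate), 0.0)  # inverted-dropout scaling
    new_key = int(rng.integers(0, 2**63 - 1))    # key to feed the next step
    return out, new_key

def dropout(x, key, rate=0.5, training=True):
    """Mode switch: a frontend could select the variant when parsing."""
    if not training:
        return x, key  # inference: identity, key untouched
    return train_dropout(x, key, rate)

def batch_norm_train(x, state, momentum=0.9, eps=1e-5):
    """Training-mode batchnorm: moving stats come in and go out explicitly."""
    mean, var = x.mean(axis=0), x.var(axis=0)
    new_state = {
        "mean": momentum * state["mean"] + (1.0 - momentum) * mean,
        "var": momentum * state["var"] + (1.0 - momentum) * var,
    }
    return (x - mean) / np.sqrt(var + eps), new_state

# The caller owns all state and threads it through each iteration.
key = 42
bn_state = {"mean": np.zeros(5), "var": np.ones(5)}
x = np.random.default_rng(0).normal(size=(8, 5))
h, bn_state = batch_norm_train(x, bn_state)
h, key = dropout(h, key, rate=0.5, training=True)
```

Because nothing is mutated inside these functions, the same state-holding code that owns the weights could also own the dropout key and the batchnorm statistics.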
Thank you for your time on this post~ I genuinely hope we can reach some conclusions on these two parts.