As demand for training support in TVM grows, one of the most tedious but important tasks is writing backward implementations for operators. Automation tools that help with this process would be of great benefit. Such a tool can serve two functions:
- Automatically create a backward definition from a forward definition
- Check gradients given forward and backward definitions
Traditional deep learning frameworks (with the possible exception of Theano) perform automatic back-propagation at the op graph level; that is, they have to implement one backward op for each forward op. In principle, 10000 forward ops require 10000 backward op definitions.
For TVM, however, there is an opportunity to perform back-propagation at the tensor expression level. There are far fewer tensor expression operations than operators in the full neural network operator set, so this would greatly reduce manual work at the higher level (relay ops).
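To make the argument concrete, here is a minimal sketch in plain Python (not TVM code; every class and function name below is hypothetical). It assumes a toy expression language with only three primitives, each carrying its own differentiation rule. Composite "operators" built from those primitives get their derivatives without any per-operator backward definition.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Const:
    value: float

    def evaluate(self, env):
        return self.value

    def diff(self, wrt):
        return Const(0.0)


@dataclass(frozen=True)
class Var:
    name: str

    def evaluate(self, env):
        return env[self.name]

    def diff(self, wrt):
        return Const(1.0 if self.name == wrt else 0.0)


@dataclass(frozen=True)
class Add:
    lhs: object
    rhs: object

    def evaluate(self, env):
        return self.lhs.evaluate(env) + self.rhs.evaluate(env)

    def diff(self, wrt):  # sum rule
        return Add(self.lhs.diff(wrt), self.rhs.diff(wrt))


@dataclass(frozen=True)
class Mul:
    lhs: object
    rhs: object

    def evaluate(self, env):
        return self.lhs.evaluate(env) * self.rhs.evaluate(env)

    def diff(self, wrt):  # product rule
        return Add(Mul(self.lhs.diff(wrt), self.rhs), Mul(self.lhs, self.rhs.diff(wrt)))


@dataclass(frozen=True)
class Exp:
    arg: object

    def evaluate(self, env):
        return math.exp(self.arg.evaluate(env))

    def diff(self, wrt):  # chain rule through exp
        return Mul(Exp(self.arg), self.arg.diff(wrt))


# Two composite "operators" written only in terms of the primitives above.
def squared_error(pred, target):
    delta = Add(pred, Mul(Const(-1.0), target))
    return Mul(delta, delta)


def exp_decay(x, rate):
    return Exp(Mul(Const(-1.0), Mul(rate, x)))


env = {"x": 3.0, "t": 1.0, "r": 0.5}
d_se = squared_error(Var("x"), Var("t")).diff("x").evaluate(env)   # 2 * (x - t)
d_decay = exp_decay(Var("x"), Var("r")).diff("x").evaluate(env)    # -r * exp(-r * x)
```

Adding a third composite operator would need no new differentiation rule, which is the saving the expression-level approach aims for.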
Backward tensor expression generator
Since a tensor expression defines symbolically how to compute the output from the input, we can try applying back-propagation rules to it directly. For example, we can provide a utility interface like

    from typing import List, Tuple
    from tvm.te import Tensor

    def auto_backprop(inputs: List[Tensor], output: Tensor) -> Tuple[List[Tensor], Tensor]:
        """
        Given input tensor list and output tensor, generate backward computation.
        - The inputs of the backward computation are the placeholder representing the
          gradient with respect to the original output and some other necessary original tensors.
        - The outputs are the gradients with respect to each of the original inputs.
        """
        pass
Now if we have already defined some forward computation, then we can extract a “default” backward computation definition:

    x = te.placeholder((n, k))
    y = te.placeholder((m, k))
    z = te.compute((n, m), ...)
    ((grad_x, grad_y), grad_z_placeholder) = te.auto_backprop((x, y), z)
    sched = te.create_schedule(grad_x.op)
    # Do schedule and tune backward ops...

The transformation should happen before create_schedule(), since forward and backward definitions generally differ and may not share the same optimization strategies.
We can wrap this sort of utility in topi and relay, so that default backward op definitions are provided automatically wherever possible, without hand-written definitions. Some pros and cons are listed below:
- Pro: avoids hand-written work for at least some portion of operators.
- Pro: auto-generated definitions may be more robust in boundary behavior and corner cases.
- Con: it is not all-powerful; not every operator can be differentiated automatically.
- Con: some optimization hints may be lost (the backward of matmul is also a matmul, and the backward of conv2d is also a conv2d).
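
The remark that the backward of matmul is also a matmul can be illustrated with a small plain-Python sketch (helper names are hypothetical, not TVM APIs). It assumes the forward form output[i][j] = sum_k data[i][k] * weight[j][k] and shows that the gradient with respect to weight, written as an explicit sum, equals the same matmul kernel applied to transposed operands. A generic generated expression would not carry that hint on its own.

```python
def matmul_nt(a, b):
    """out[i][j] = sum_k a[i][k] * b[j][k] (second operand indexed as [j][k])."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


data = [[1, 2, 3], [4, 5, 6]]                  # indexed (i, k), shape (2, 3)
grad_output = [[1, -1, 2, 0], [0, 3, 1, -2]]   # indexed (i, j), shape (2, 4)

# Gradient w.r.t. weight[j][k] written as the explicit reduction over i.
grad_weight_sum = [
    [sum(data[i][k] * grad_output[i][j] for i in range(2)) for k in range(3)]
    for j in range(4)
]

# The same result obtained by reusing the forward kernel on transposed inputs.
grad_weight_mm = matmul_nt(transpose(grad_output), transpose(data))
```

If the backward is recognized as a matmul, an existing tuned matmul schedule could in principle be reused for it.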
At the beginning we may just focus on te.compute(), and not support tensor intrinsic / hybrid / extern.
- Use simple matmul as an example

      te.compute((m, n), lambda i, j: te.sum(data[i, k] * weight[j, k], axis=k))

  If we want to compute the gradient with respect to weight[w1][w2], we have to know how the output is related to this weight position. Thus we “remap” the iter vars related to weight:

      j = w1, k = w2

  Then all iter vars in the compute expression can be represented by [w1, w2] through affine transformations.

      te.sum(data[i, w2] * weight[w1, w2], axis=..) (for i, j=w1)

  i is the remaining free variable. It can be seen that each weight[w1, w2] contributes to output[i, w1] for every feasible i. For each i, the gradient of output[i, w1] with respect to weight[w1, w2] is data[i, w2]. According to the chain rule, the gradient of the loss with respect to weight[w1, w2] can be computed as

      te.sum(data[i, w2] * grad_output[i, w1], axis=i)

- The actual back-propagation logic should carefully handle iter var relationships. For each occurrence of the target tensor in the expression, the feasible integer set of each free iter var is inferred from the iter var remapping. With the free vars fixed, the gradient expression of the output expression with respect to the target tensor position is computed. Finally, the chain rule is applied to sum the gradient expressions over the free vars' feasible sets. Unsupported cases should be detected explicitly.
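
The feasible-set inference described above matters once the index expressions are less trivial than in matmul. The following plain-Python sketch (hypothetical helper names, not TVM code) assumes a 1-D "valid" convolution out[i] = sum_k data[i + k] * kernel[k]. For the gradient with respect to data[d], the remapping is i + k = d, so k stays free and i = d - k. The feasible set of k is bounded both by the kernel size and by the requirement that d - k is a valid output index.

```python
def conv1d_forward(data, kernel):
    """out[i] = sum_k data[i + k] * kernel[k], for every i where the window fits."""
    out_len = len(data) - len(kernel) + 1
    return [sum(data[i + k] * kernel[k] for k in range(len(kernel))) for i in range(out_len)]


def grad_data_gather(grad_out, kernel, data_len):
    """Gradient per data position d, summing over the feasible set of the free var k."""
    out_len = len(grad_out)
    grad = []
    for d in range(data_len):
        # Constraints: 0 <= k < len(kernel) and 0 <= d - k < out_len.
        k_lo = max(0, d - out_len + 1)
        k_hi = min(len(kernel) - 1, d)
        grad.append(sum(grad_out[d - k] * kernel[k] for k in range(k_lo, k_hi + 1)))
    return grad


def grad_data_scatter(grad_out, kernel, data_len):
    """Reference: walk the forward iteration space and accumulate contributions."""
    grad = [0] * data_len
    for i in range(len(grad_out)):
        for k in range(len(kernel)):
            grad[i + k] += grad_out[i] * kernel[k]
    return grad


kernel = [2, -1, 3]
grad_out = [1, 0, -2, 4]
data_len = len(grad_out) + len(kernel) - 1   # 6 input positions

gathered = grad_data_gather(grad_out, kernel, data_len)
scattered = grad_data_scatter(grad_out, kernel, data_len)
```

The gather form is the one that fits te.compute(), since each output element of the gradient is written exactly once. Getting k_lo and k_hi wrong would read out of bounds or drop contributions at the borders, which is the kind of corner case the inference has to handle.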
te.scan() is also an interesting operation worth supporting for back-propagation; with it we can get backward implementations of RNN/LSTM/GRU directly.
Gradient checking between forward and backward ops
Given a forward and backward implementation pair, we can verify correctness against approximate (numerical) gradients. This helps developers detect implementation errors in both general and corner cases. One such method is well described in https://datascience-enthusiast.com/DL/Improving_DeepNeural_Networks_Gradient_Checking.html
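
As a rough illustration of such a check, the plain-Python sketch below compares a backward implementation against central finite differences. It reuses the matmul form from earlier in this post; the function names and the choice of a linear scalar loss (so that the gradient of the loss with respect to the output is exactly grad_output) are assumptions made for the example, not an existing TVM utility.

```python
def forward(data, weight):
    """output[i][j] = sum_k data[i][k] * weight[j][k]"""
    return [[sum(d * w for d, w in zip(drow, wrow)) for wrow in weight] for drow in data]


def backward_weight(data, grad_output):
    """grad_weight[j][k] = sum_i data[i][k] * grad_output[i][j]"""
    n, m, kdim = len(data), len(grad_output[0]), len(data[0])
    return [[sum(data[i][k] * grad_output[i][j] for i in range(n)) for k in range(kdim)]
            for j in range(m)]


def loss(data, weight, grad_output):
    """Scalar loss whose gradient w.r.t. the output is exactly grad_output."""
    out = forward(data, weight)
    return sum(o * g for orow, grow in zip(out, grad_output) for o, g in zip(orow, grow))


def max_grad_error(backward_fn, data, weight, grad_output, eps=1e-4):
    """Largest absolute gap between backward_fn and central finite differences."""
    analytic = backward_fn(data, grad_output)
    worst = 0.0
    for j in range(len(weight)):
        for k in range(len(weight[0])):
            original = weight[j][k]
            weight[j][k] = original + eps
            plus = loss(data, weight, grad_output)
            weight[j][k] = original - eps
            minus = loss(data, weight, grad_output)
            weight[j][k] = original  # restore before moving on
            numeric = (plus - minus) / (2 * eps)
            worst = max(worst, abs(numeric - analytic[j][k]))
    return worst


data = [[0.5, -1.0], [2.0, 0.25]]
weight = [[1.0, 0.5], [-0.5, 2.0], [0.0, 1.5]]
grad_output = [[1.0, -2.0, 0.5], [0.0, 1.0, -1.0]]

good_error = max_grad_error(backward_weight, data, weight, grad_output)

# A deliberately wrong backward (gradient doubled) to show the check catches it.
def doubled(d, g):
    return [[2 * v for v in row] for row in backward_weight(d, g)]

bad_error = max_grad_error(doubled, data, weight, grad_output)
```

A practical checker would also need a tolerance that scales with the magnitude of the gradients and would sample inputs that hit boundary conditions, since those are where hand-written backward ops tend to go wrong.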