[RFC] Differentiable tensor expression (Create and verify backward op automatically)

As demand for training support in TVM grows, one of the most tedious yet important tasks is writing backward implementations for operators. Automation tools that help with this process would be of great benefit. Such a tool could serve two functions:

  • Automatically create backward definition from forward definition
  • Check gradient given forward and backward definition

Traditional deep learning frameworks (with the possible exception of Theano) perform automatic back-propagation at the op-graph level; that is, they have to implement one backward op for each forward op. In principle, a framework with 10000 forward ops needs 10000 backward op definitions.

In TVM, however, there is an opportunity to perform back-propagation at the tensor expression level. The set of tensor expression operations is much smaller than the full set of neural network operators, so this could greatly reduce the manual work needed at the higher level (Relay ops).

Backward tensor expression generator

Interface

Since a tensor expression symbolically defines how the output is computed from the inputs, we can try to apply back-propagation rules to it directly. For example, we could provide a utility interface like:

from typing import List, Tuple
from tvm import te

def auto_backprop(inputs: List[te.Tensor], output: te.Tensor) -> Tuple[List[te.Tensor], te.Tensor]:
    """
    Proposed API (not part of TVM): given the forward input tensors and output tensor, generate the backward computation.
    - Returns the gradients with respect to each of the original inputs, plus the placeholder
      that represents the gradient with respect to the original output.
    - The backward computation reads this placeholder and any other original tensors it needs.
    """
    raise NotImplementedError

Now if we have already defined some forward computation, then we can extract a “default” backward computation definition:

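from tvm import te

n, m, k = te.var("n"), te.var("m"), te.var("k")
x = te.placeholder((n, k), name="x")
y = te.placeholder((m, k), name="y")
z = te.compute((n, m), ...)  # forward body elided

# te.auto_backprop is the API proposed here, not an existing TVM function.
(grad_x, grad_y), grad_z_placeholder = te.auto_backprop((x, y), z)
sched = te.create_schedule(grad_x.op)
# Do schedule and tune backward ops...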

The transformation should happen before create_schedule(), since forward and backward definitions generally differ and may not share the same optimization strategies.

We can wrap this sort of utility in TOPI and Relay, so that default backward op definitions are provided automatically wherever possible, without hand-written definitions. Some pros and cons are listed below:

  • Pros
    • Avoid hand-written work for at least some portion of operations.
    • Auto-generated definitions may be more robust with respect to boundary behaviors and corner cases.
  • Cons
    • It is not all-powerful: not every operator can be differentiated automatically.
    • Some optimization hints may be lost (e.g. the backward of matmul is also a matmul, and the backward of conv2d is also a conv2d).
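To make the second con concrete, here is a minimal NumPy sketch (NumPy stands in for TVM, and the function names are hypothetical) of a hand-written matmul backward. Both gradients are again plain matrix products, so an already tuned matmul schedule can be reused as is. An automatically generated backward would instead surface as a generic reduction over a free variable, and that "it is just a matmul" structure would have to be rediscovered before the same optimization applies.

```python
import numpy as np

# Hypothetical helpers; NumPy stands in for TVM tensor expressions.
# Forward: out[i, j] = sum_k data[i, k] * weight[j, k], i.e. out = data @ weight.T
def matmul_forward(data, weight):
    return data @ weight.T

# Hand-written backward: both gradients are themselves matmuls.
def matmul_backward(data, weight, grad_out):
    grad_data = grad_out @ weight    # grad_data[i, k] = sum_j grad_out[i, j] * weight[j, k]
    grad_weight = grad_out.T @ data  # grad_weight[j, k] = sum_i grad_out[i, j] * data[i, k]
    return grad_data, grad_weight
```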

Transformation logic

As a start, we may focus on te.compute() only, and leave tensor intrinsics / hybrid / extern ops unsupported.

  • te.compute()

    • Use a simple matmul as an example, where k is a te.reduce_axis over the shared dimension:
      te.compute((m, n), lambda i, j: te.sum(data[i, k] * weight[j, k], axis=k))
      
      To compute the gradient with respect to weight[w1, w2], we have to know how the output depends on this weight position. Thus we “remap” the iter vars used to index weight:
      j = w1, k = w2
      
      Then every iter var in the compute expression except i is an affine function of [w1, w2], and the only term of the reduction that involves weight[w1, w2] is
      data[i, w2] * weight[w1, w2]  (contributing to output[i, w1])
      
      Here i remains a free variable: each weight[w1, w2] contributes to output[i, w1] for every feasible i. For a fixed i, the gradient of output[i, w1] with respect to weight[w1, w2] is data[i, w2]. By the chain rule, the gradient of the loss with respect to weight[w1, w2] is therefore
      te.sum(data[i, w2] * grad_output[i, w1], axis=i)
    
    • The actual back-propagation logic must carefully handle the relationships between iter vars. For each occurrence of the target tensor in the expression, the feasible integer set of each free iter var is inferred from the iter var remapping. With the free vars fixed, we compute the gradient of the output expression with respect to the target tensor position. Finally, the chain rule is applied by summing the gradient expression over the free vars’ feasible sets (and over all occurrences of the target tensor). Unsupported cases should be detected explicitly.
  • te.scan() is another interesting operation worth supporting for back-propagation; with it, we could obtain backward implementations of RNN/LSTM/GRU directly.
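The remap-and-sum rule can be prototyped outside TVM for the restricted case where a compute body is a product of tensor accesses indexed directly by iter vars and reduced with a sum, i.e. an einsum. In the minimal NumPy sketch below (an illustration under that assumption, not the proposed TVM implementation; `einsum_grad` is a hypothetical helper), the target operand's own indices play the role of the remapped positions [w1, w2], and every index that appears elsewhere but not in the target is a free variable that `np.einsum` sums over, which is the chain-rule sum described above.

```python
import numpy as np

def einsum_grad(spec, operands, grad_out, target):
    """Gradient of sum(grad_out * einsum(spec, *operands)) w.r.t. operands[target].

    Hypothetical helper. Assumes each index letter appears at most once per
    operand, and every index of the target also appears in another operand or
    in the output (no index that exists only in the target).
    """
    inputs, output = spec.split("->")
    in_specs = inputs.split(",")
    # Remapping: the target's own indices become the output indices of the backward einsum.
    target_spec = in_specs[target]
    # Multiply everything else together; indices absent from target_spec are the
    # free variables, and einsum sums over them.
    other_specs = [s for i, s in enumerate(in_specs) if i != target]
    other_ops = [op for i, op in enumerate(operands) if i != target]
    back_spec = ",".join(other_specs + [output]) + "->" + target_spec
    return np.einsum(back_spec, *other_ops, grad_out)

# Usage on the matmul example: "ik,jk->ij" means out[i, j] = sum_k data[i, k] * weight[j, k].
data = np.arange(6.0).reshape(2, 3)
weight = np.ones((4, 3))
grad_out = np.ones((2, 4))
grad_weight = einsum_grad("ik,jk->ij", [data, weight], grad_out, target=1)  # every row is [3., 5., 7.]
```

If a tensor occurs more than once in the expression (e.g. `x * x`), the gradients computed for each occurrence would be added together. Affine index arithmetic such as `data[i + r]` (which needs the feasible-set inference described above), non-linear operations, and conditionals are outside this sketch.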

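To see why te.scan() support would pay off for RNN-style ops: the backward of a scan is itself a scan that runs in reverse time. A minimal pure-Python sketch, assuming the simplest linear recurrence h[t] = a * h[t-1] + x[t] with initial state h0 (the recurrence and the function names are illustrative assumptions, not TVM APIs):

```python
def scan_forward(x, a, h0=0.0):
    """Illustrative recurrence: h[t] = a * h[t-1] + x[t], with h[-1] = h0."""
    h, prev = [], h0
    for xt in x:
        prev = a * prev + xt
        h.append(prev)
    return h

def scan_backward(a, h, grad_h, h0=0.0):
    """Reverse-time scan: g[t] = grad_h[t] + a * g[t+1].

    grad_h[t] is the loss gradient flowing directly into h[t].
    Returns (grad_x, grad_a).
    """
    T = len(h)
    grad_x = [0.0] * T
    grad_a = 0.0
    carry = 0.0  # total dL/dh[t+1], propagated back through dh[t+1]/dh[t] = a
    for t in reversed(range(T)):
        g = grad_h[t] + a * carry
        grad_x[t] = g  # dh[t]/dx[t] = 1
        prev = h[t - 1] if t > 0 else h0
        grad_a += g * prev  # dh[t]/da = h[t-1]
        carry = g
    return grad_x, grad_a
```

In an LSTM or GRU the per-step derivative is a Jacobian rather than the scalar a, but the structure, a carry propagated from step t+1 back to step t, stays the same.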
Gradient checking between forward and backward ops

Given a forward and backward implementation pair, we can verify correctness against approximate (numerical) gradients. This helps developers detect implementation errors in both general and corner cases. One such method is described well in https://datascience-enthusiast.com/DL/Improving_DeepNeural_Networks_Gradient_Checking.html
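A minimal NumPy sketch of such a check, using central finite differences and a relative-error measure in the spirit of the linked tutorial. Because a backward op computes a vector-Jacobian product, the sketch projects the output onto a random grad_out, i.e. it uses the scalar loss sum(forward(x) * grad_out), whose exact gradient is what backward(x, grad_out) should return. The helper names, eps and tolerance are assumptions that would need tuning per dtype.

```python
import numpy as np

def numeric_grad(loss, x, eps=1e-6):
    """Central finite-difference gradient of the scalar function `loss` at `x`."""
    x = np.array(x, dtype=float)  # float copy; each probed entry is restored afterwards
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        orig = x[idx]
        x[idx] = orig + eps
        f_plus = loss(x)
        x[idx] = orig - eps
        f_minus = loss(x)
        x[idx] = orig
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad

def check_gradient(forward, backward, x, grad_out, eps=1e-6, tol=1e-6):
    """Compare backward(x, grad_out) against a numerical gradient.

    Returns (passed, relative_error).
    """
    def loss(v):
        return float(np.sum(forward(v) * grad_out))

    analytic = np.asarray(backward(x, grad_out), dtype=float)
    numeric = numeric_grad(loss, x, eps)
    denom = np.linalg.norm(analytic) + np.linalg.norm(numeric)
    rel_err = np.linalg.norm(analytic - numeric) / denom if denom > 0 else 0.0
    return rel_err < tol, rel_err

# Usage: a correct and a deliberately wrong backward for tanh.
rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4))
grad_out = rng.standard_normal((3, 4))
ok, err = check_gradient(np.tanh, lambda v, g: g * (1 - np.tanh(v) ** 2), x, grad_out)
bad_ok, bad_err = check_gradient(np.tanh, lambda v, g: g * (1 - np.tanh(v)), x, grad_out)
```

Using a random grad_out rather than all ones makes it less likely that an error in the backward op cancels out across output elements.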

Hey @wrongtest,

Thank you for the RFC! Just wondering how it compares with the previous AD RFC ([RFC] Bring in Tensor Expression Autodiff)?

Thanks!


Glad to see autodiff is already in progress! I think this RFC can be withdrawn, since this is exactly what autodiff does.

Now I am very curious about the current progress of autodiff, and have some questions:

  • If I have a common neural network such as ResNet-50 at hand, can I just use autodiff to get the backward computation graph?
  • Is there a description of which common ops are covered by autodiff?
  • Can te.scan() be supported?

CC: @yzhliu, the major contributor of this feature

Graph-wise, I think you can refer to relay.transform.gradient; as you lower the differentiated graph, you may leverage tensor-level autodiff (te.gradient). That said, tensor gradients are currently still mostly written by hand.

For op coverage, you may refer to the test cases.

As for te.scan(), it is currently not supported.
