I found some broken cases in higher-order AD when dealing with If or Tuple expressions.
import tvm
import tvm.relay as relay
x = relay.var("x", shape=(1, 16, 64, 64))
y = relay.var("y", shape=(1, 16, 64, 64))
cond = relay.var("cond", shape=(), dtype='uint1')
net = relay.If(cond, x, y)
net = relay.log(net)
net = relay.ir_pass.infer_type(relay.Function(relay.ir_pass.free_vars(net), net))
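# The next line fails: the (forward, ref) pair ends up as the If condition
# (see the explanation below) and type inference rejects the program.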
back_func = relay.ir_pass.infer_type(relay.ir_pass.gradient(net, mode='higher_order'))
Another case:
import tvm
import tvm.relay as relay
x = relay.var("x", shape=(1, 16, 64, 64))
y = relay.var("y", shape=(1, 16, 64, 64))
net = relay.Tuple([x, y])
net = relay.ir_pass.infer_type(relay.Function(relay.ir_pass.free_vars(net), net))
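# Fails similarly: the Tuple case also misses the projection and type inference errors.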
back_func = relay.ir_pass.infer_type(relay.ir_pass.gradient(net, mode='higher_order'))
The problem is that ReverseAD transforms every expression into a (forward, ref_reverse) pair. For Call expressions there is an explicit TupleGetItem that projects out the first field of the pair to use as the argument, but this projection is missing for other expression kinds such as If and Tuple. As a result, the whole pair is used as the condition of the If, which causes an error during type inference.
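As a rough illustration of the shape described above (hand-built Relay expressions, not the literal output of the pass):
import tvm.relay as relay
x = relay.var("x", shape=(1, 16, 64, 64))
# A leaf becomes (forward value, reference holding its accumulated gradient).
pair = relay.Tuple([x, relay.RefCreate(relay.zeros_like(x))])
# A Call argument is projected back to the forward value before use ...
forward = relay.TupleGetItem(pair, 0)
net = relay.log(forward)
# ... but an If condition is not, so the whole pair lands where a boolean
# scalar is expected and type inference rejects the program.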
The If case is a plain bug, and I am fixing it right now.
The Tuple case requires a bit more thought: I am considering an extensible AD interface so we can support other kinds of data types, and we could then use it to support tuples. In the original design, the AD interface takes a whole program from a tuple of tensors to a tensor.
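A purely hypothetical sketch of what such an interface could look like (none of these names exist in Relay; this is only to make the idea concrete): the pass would look up a per-type handler instead of hard-coding tensors.
import tvm.relay as relay
# Hypothetical registry mapping a Relay type to a handler that builds a
# zero adjoint for values of that type.
ZERO_ADJOINT = {}
def register_zero_adjoint(type_key):
    def _register(fn):
        ZERO_ADJOINT[type_key] = fn
        return fn
    return _register
@register_zero_adjoint(relay.TensorType)
def _tensor_zero(value):
    # For a tensor, the adjoint starts as an all-zero tensor of the same shape.
    return relay.zeros_like(value)
# TupleType, the prelude list ADT, and other datatypes would register their
# own handlers, so the gradient pass only depends on the interface.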
import tvm
import tvm.relay as relay
acc = relay.Var("acc", relay.TensorType(shape=(1, 16)))
x = relay.Var("x", relay.TensorType(shape=(1, 16)))
w = relay.Var("w", relay.TensorType(shape=(1, 16)))
fn = relay.Function([acc, x], relay.add(acc, relay.multiply(x, w)))
m = relay.module.Module()
prelude = relay.prelude.Prelude(m)
xs = relay.Var("xs", prelude.l(relay.TensorType(shape=(1, 16))))
init = relay.zeros((1, 16), dtype='float32')
F = prelude.foldl(fn, init, xs)
F = relay.Function([w, xs], F)
main_ = relay.GlobalVar('main')
m[main_] = F
print(m[main_])
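# Fails: higher-order AD rewrites the list element type from Tensor into a
# (Tensor, ref) pair, and type inference then rejects the program.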
F = relay.ir_pass.gradient(F, m)
m[main_] = F
print(F)
I'm trying to fold over a list of data. The pass transforms xs: list[Tensor] into list[(Tensor, ref)] and causes an error in type inference.
In this case, only the gradients of the weights are needed, not the gradients of the input data.
@jroesch Error annotation also seems buggy in this case. It prints TVMError: Error(s) have occurred. We have annotated the program with them: but nothing is actually annotated.
If the gradients of the input data are not used, dead code elimination should remove them. The DCE pass currently does not respect effects, though, and I am trying to find time to fix that.
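A minimal hand-built illustration (using only the Relay ref operations shown) of why DCE has to reason about effects before it can safely drop unused gradient code: the let-bound RefWrite below is dead by data flow alone, yet removing it would change the value later read from the reference.
import tvm.relay as relay
r = relay.Var("r")
ignored = relay.Var("ignored")
body = relay.Let(
    r, relay.RefCreate(relay.const(0.0)),
    relay.Let(
        ignored, relay.RefWrite(r, relay.const(1.0)),  # result unused, but effectful
        relay.RefRead(r)))                             # must still observe 1.0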