I am trying to register the gradient for the split operator.
This is the gradient code that I am trying to use.
@register_gradient("split")
def split_grad(orig, grad):
grad = relay.concatenate(grad, axis=orig.attrs.axis)
return [grad]
When I try to register the gradient, I am getting this error.
TypeError: 'Var' object is not iterable
But if I comment out the concatenate line and directly pass the grad as-is, the gradient transform succeeds and a gradient IR module is produced; however, it gives me this error when I try to build the gradient module:
Check failed: (node != nullptr) is false: Expected type to be relay.TensorType, but get TupleType
I have attached a fully reproducible example below. Is it possible to get the gradients of the split operator in this manner?
import tvm
from tvm import relay
from tvm.relay.testing import run_infer_type
import numpy as np
from tvm.contrib import graph_executor
from tvm.relay.op import register_gradient
@register_gradient("split")
def split_grad(orig, grad):
    """Gradient of split: concatenate the incoming output gradients.

    For a multi-output op like split, `grad` is a single tuple-typed
    Relay expression (typically a Var), not a Python sequence.
    relay.concatenate calls list(data) on its argument, which is what
    raises `TypeError: 'Var' object is not iterable`. Unpack the tuple
    fields explicitly with TupleGetItem before concatenating.
    """
    # The gradient pass runs on a type-inferred function, so orig's
    # checked_type is the TupleType of the split outputs — its field
    # count is the number of pieces to re-join.
    num_outputs = len(orig.checked_type.fields)
    fields = [relay.TupleGetItem(grad, i) for i in range(num_outputs)]
    # Re-join the output gradients along the axis split divided them on.
    return [relay.concatenate(fields, axis=int(orig.attrs.axis))]
# ---- Build the forward function: split x along axis 1, then re-concatenate ----
shape = (3, 3)
dtype = "float32"
# NOTE(review): low == high makes every element exactly 2.0 — presumably a
# deliberate constant input for eyeballing gradients; if a random input was
# intended, low=-2 was probably meant. Confirm.
x_np = np.random.uniform(low=2, high=2, size=shape).astype(dtype)
tp = relay.TensorType(shape, dtype=dtype)
x = relay.var("x", type_annotation=tp)
v1 = relay.split(x, indices_or_sections=[2], axis=1)  # (3,3) -> (3,2), (3,1)
call = relay.concatenate(v1, axis=1)  # round-trip back to (3,3)
exp = relay.Function([x], call)
exp = run_infer_type(exp)

# ---- Differentiate (first-order) and normalize the module ----
bwd_exp = relay.transform.gradient(exp, mode='first_order')
bwd_mod = tvm.IRModule.from_expr(bwd_exp)
bwd_mod = relay.transform.InferType()(bwd_mod)
bwd_mod = relay.transform.ToGraphNormalForm()(bwd_mod)
print(bwd_mod["main"])

# ---- Compile for CPU and execute ----
target = 'llvm'
ctx = tvm.device(str(target), 0)
with tvm.transform.PassContext(opt_level=3):
    lib = tvm.relay.build(bwd_mod,
                          target=target,
                          params={})
params = lib.params
compiled_module = graph_executor.GraphModule(lib["default"](ctx))
compiled_module.set_input("x", x_np)
compiled_module.run()
# The gradient function returns the forward output plus the gradients
# w.r.t. the inputs, so collect every output the executor exposes.
outs = [compiled_module.get_output(i) for i in range(compiled_module.get_num_outputs())]
print(x_np)
print(outs)