Register gradient of Split operator in relay

I am trying to register the gradient for the split operator.

This is the gradient code that I am trying to use.

@register_gradient("split")
def split_grad(orig, grad):
    # NOTE(review): split is a multi-output op, so `grad` arrives as a single
    # relay expression of TupleType (under first-order AD it is a plain Var),
    # not a Python list of tensors. relay.concatenate tries to iterate it,
    # which produces the "TypeError: 'Var' object is not iterable" reported
    # below. The tuple fields must be exposed (e.g. via relay.TupleWrapper)
    # before concatenating.
    grad = relay.concatenate(grad, axis=orig.attrs.axis)
    return [grad]

When I try to register the gradient, I am getting this error.

TypeError: 'Var' object is not iterable

But if I comment out the concatenate line and directly pass the grad as-is, the gradient transform succeeds and a gradient IR module is produced, but it gives me this error when I try to build the gradient module:

Check failed: (node != nullptr) is false: Expected type to be relay.TensorType, but get TupleType

I have attached full reproducible code below. Is it possible to get the gradients of the split operator in this manner?

import tvm
from tvm import relay
from tvm.relay.testing import run_infer_type
import numpy as np
from tvm.contrib import graph_executor

from tvm.relay.op import register_gradient
@register_gradient("split")
def split_grad(orig, grad):
    # NOTE(review): `grad` is a single TupleType expression (one gradient per
    # split output), not an iterable of tensors — this line raises the
    # "TypeError: 'Var' object is not iterable" described above.
    grad = relay.concatenate(grad, axis=orig.attrs.axis)
    return [grad]

# --- Reproduction script ---------------------------------------------------
# Builds fn(x) = concatenate(split(x, [2], axis=1), axis=1), takes its
# first-order gradient, compiles the gradient module for LLVM, and runs it.

shape = (3, 3)
dtype = "float32"

# low == high, so the input is a constant array of 2.0 — presumably just a
# convenient fixed test input.
x_np = np.random.uniform(low=2, high=2, size=shape).astype("float32")
tp = relay.TensorType(shape, dtype=dtype)

x = relay.var("x", type_annotation=tp)
split_parts = relay.split(x, indices_or_sections=[2], axis=1)

# Forward function: split the input and immediately stitch it back together.
fwd_fn = relay.Function([x], relay.concatenate(split_parts, axis=1))
fwd_fn = run_infer_type(fwd_fn)

# Differentiate, wrap in a module, and normalize for the graph executor.
grad_fn = relay.transform.gradient(fwd_fn, mode='first_order')
grad_mod = tvm.IRModule.from_expr(grad_fn)
grad_mod = relay.transform.InferType()(grad_mod)
grad_mod = relay.transform.ToGraphNormalForm()(grad_mod)
print(grad_mod["main"])

target = 'llvm'
ctx = tvm.device(str(target), 0)

with tvm.transform.PassContext(opt_level=3):
    lib = tvm.relay.build(grad_mod,
                          target=target,
                          params={})
    params = lib.params

runner = graph_executor.GraphModule(lib["default"](ctx))
runner.set_input("x", x_np)
runner.run()

outs = [runner.get_output(i) for i in range(runner.get_num_outputs())]

print(x_np)
print(outs)

This seems to have worked, though the code has not been fully tested.

@register_gradient("split")
def split_grad(orig, grad):
    """Gradient of split: concatenate the per-output gradients back together.

    Because split produces multiple outputs, ``grad`` is a single relay
    expression of TupleType holding one gradient per split section. Wrap it
    in a ``TupleWrapper`` to expose the fields, then concatenate them along
    the original split axis to recover a gradient matching the input shape.
    """
    num_outputs = len(orig.checked_type.fields)
    grads = relay.TupleWrapper(grad, num_outputs)
    # relay.concatenate accepts any sequence of expressions (it calls
    # list() on its argument), so the TupleWrapper can be passed directly —
    # no need to materialize an intermediate Python list.
    return [relay.concatenate(grads, orig.attrs.axis)]