TVM Gradient behaves like a No-Op using `tvm.relay.create_executor`

Hello,

My goal is to convert PyTorch functions to TVM so that I can perform differential testing across both implementations of the same function. As part of this testing, I'd like to compare gradients as well.
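Concretely, the comparison step boils down to a small helper along these lines (the names here are just illustrative, not my exact harness):

```python
import numpy as np

def assert_close(torch_result, tvm_result, rtol=1e-5, atol=1e-6):
    """Check that PyTorch and TVM outputs agree elementwise.

    Both arguments are anything convertible to a NumPy array
    (e.g. `tensor.detach().numpy()` on the PyTorch side and
    `result.asnumpy()` on the TVM side).
    """
    np.testing.assert_allclose(np.asarray(torch_result),
                               np.asarray(tvm_result),
                               rtol=rtol, atol=atol)
```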

To accomplish this, I convert the model using `relay.frontend.from_pytorch`, take the TVM gradient using `tvm.relay.transform.gradient`, and execute the resulting function with `tvm.relay.create_executor`. The code below is a minimal example:

        import torch
        import tvm
        from tvm import relay
        from tvm.relay import transform

        def module_gradient(model):
            # Simplify the module first, then take the gradient of main.
            seq = transform.Sequential([
                relay.transform.FoldConstant(),
                relay.transform.PartialEvaluate(),
                relay.transform.DeadCodeElimination(inline_once=True),
                relay.transform.PartialEvaluate(),
                relay.transform.DeadCodeElimination(inline_once=True),
                relay.transform.InferType()])

            typed_model = seq(model)
            main_var = typed_model.get_global_vars()[0]
            main_fun = typed_model.functions[main_var]
            grad_main = tvm.relay.transform.gradient(main_fun)
            return tvm.ir.IRModule.from_expr(grad_main)

        torch_model = lambda x: 2 * x
        xs = [torch.tensor([1.,2.,3.])]
        torchscript_model = torch.jit.trace(torch_model, xs)

        shape_list = [(i.debugName().split('.')[0], i.type().sizes()) for i in list(torchscript_model.graph.inputs())]
        tvm_model, tvm_params = relay.frontend.from_pytorch(torchscript_model, shape_list)

        grad_tvm_model = module_gradient(tvm_model)

        target = tvm.target.cuda()
        ctx = tvm.gpu()
        tvm.relay.backend.compile_engine.get().clear()
        with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
            tvm_inputs = [tvm.relay.Constant(tvm.nd.array(x.detach().numpy(), ctx)) for (_, x) in zip(shape_list, xs) ]

            intrp = tvm.relay.create_executor(mod=grad_tvm_model, ctx=ctx, target=target)

            main_var = grad_tvm_model.get_global_vars()[0]
            main_fun = grad_tvm_model.functions[main_var]
            result = intrp.evaluate(tvm.relay.Call(main_fun, tvm_inputs))

            print(result[0])

I expect the example above to print `[2, 2, 2]` (since the gradient of `lambda x: 2 * x` is the constant function `lambda x: 2`), but instead it prints `[2, 4, 6]`. In other words, applying the gradient transform seems to have had no effect.
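As a framework-independent sanity check on the expected value, a central finite-difference estimate of the elementwise derivative also comes out to `[2, 2, 2]`:

```python
def f(x):
    return 2 * x

# Central finite-difference estimate of df/dx at each input element.
eps = 1e-6
xs = [1.0, 2.0, 3.0]
grads = [(f(x + eps) - f(x - eps)) / (2 * eps) for x in xs]
print(grads)  # approximately [2.0, 2.0, 2.0]
```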

Any pointers here would be appreciated!

Best, Edwin