`relay.transform.gradient` successfully generates the IR, but it fails to compile

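I am trying to use TVM to compute gradients of MobileNetV2 imported from PyTorch. After disabling the dropout and BN layers, the forward (fwd) IR can now be transformed into a [fwd + bwd] IR without problems. However, when I use `relay.build` to compile that IR, it raises an unexpected error. The traceback points to type inference of an `nn.conv2d_transpose` (`Conv2DTransposeRel`). That op does not appear in the forward IR, so it presumably comes from the gradient. The failing check involves `groups`, and the only grouped convolution in the model is the depthwise `nn.conv2d` with `groups=32`. Can anyone advise how to work through this issue?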

import torch
import torch as th
import torch.nn as nn
from torchvision import models
import torch.onnx 

import numpy as np

import tvm
from tvm import relay
from tvm import relay, auto_scheduler
from tvm.relay import testing


def disable_dropout_bn(module, *, module_types=(nn.BatchNorm2d, nn.Dropout)):
    """Recursively replace BatchNorm2d / Dropout layers with ``nn.Identity``.

    Container modules are rewritten in place: every child is re-registered
    under its original name with its (possibly replaced) version. The object
    returned is therefore ``module`` itself, unless ``module`` is itself an
    instance of ``module_types``.

    Args:
        module: The ``nn.Module`` to process.
        module_types: Module class (or tuple of classes) to replace with
            ``nn.Identity``. Defaults to ``(nn.BatchNorm2d, nn.Dropout)``.
            Pass e.g. ``(nn.BatchNorm1d, nn.BatchNorm2d, nn.Dropout,
            nn.Dropout2d)`` to strip other variants too. Those classes are
            not subclasses of the defaults, so they are not matched otherwise.

    Returns:
        A fresh ``nn.Identity()`` if ``module`` matches ``module_types``.
        Otherwise ``module``, with its descendants rewritten in place.
    """
    if isinstance(module, module_types):
        # The replaced layer is dropped entirely. Do not carry its children
        # over: they would otherwise show up in the Identity's parameters().
        return nn.Identity()
    for name, child in module.named_children():
        # add_module() with an already-registered name overwrites that child;
        # only existing keys are reassigned, so iteration stays valid.
        module.add_module(name, disable_dropout_bn(child, module_types=module_types))
    return module

# Build a small PyTorch model: the first two entries of torchvision
# MobileNetV2's `features`, with BatchNorm2d / Dropout layers swapped for
# nn.Identity by disable_dropout_bn().
model = models.MobileNetV2().features[:2] # only take two layers for simplicity
model = disable_dropout_bn(model)

# Trace the model with a random input of `input_shape` to get a TorchScript module.
# NOTE(review): tracing runs while `model` is still in train mode; .eval() is
# only applied to the traced result. With BN/Dropout removed this presumably
# makes no difference, but calling model.eval() before tracing is the usual pattern.
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()

print("==" * 25, "Scripted Model", "==" * 25)
print(scripted_model.code)

# Import into Relay. The graph input is named "input0" with the traced input's
# shape; `params` holds the model weights keyed by their Relay parameter names.
input_name = "input0"
shape_list = [(input_name, input_data.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list, use_parser_friendly_name=True)

# Simplify the imported module: partial evaluation, dead-code elimination,
# then conversion to graph normal form.
mod = tvm.transform.Sequential([
    relay.transform.PartialEvaluate(),
    relay.transform.DeadCodeElimination(),
    relay.transform.ToGraphNormalForm(),
])(mod)
print("==" * 25, "TVM IRs", "==" * 25)
print(mod)
# Pasted output of print(mod) above, kept for reference. Note that %3 is a
# grouped (groups=32) conv2d.
"""
def @main(%input0: Tensor[(1, 3, 224, 224), float32], %v0_0_weight: Tensor[(32, 3, 3, 3), float32], %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32], %v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> Tensor[(1, 16, 112, 112), float32] {
  %0 = nn.conv2d(%input0, %v0_0_weight, strides=[2, 2], padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 112, 112), float32] */;
  %1 = clip(%0, a_min=0f, a_max=6f) /* ty=Tensor[(1, 32, 112, 112), float32] */;
  %2 = reshape(%v1_conv_0_0_weight, newshape=[32, 1, 3, 3]) /* ty=Tensor[(32, 1, 3, 3), float32] */;
  %3 = nn.conv2d(%1, %2, padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 112, 112), float32] */;
  %4 = clip(%3, a_min=0f, a_max=6f) /* ty=Tensor[(1, 32, 112, 112), float32] */;
  nn.conv2d(%4, %v1_conv_1_weight, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1]) /* ty=Tensor[(1, 16, 112, 112), float32] */
}
"""
# Sanity check: the forward-only module compiles for the "llvm" target.
target = "llvm"
lib = relay.build(mod, target=target, params=params)
print("build [fwd] pass successful")

# Run type inference on the module, then derive a combined forward+backward
# function from the typed `main` using first-order gradient mode. Wrap the
# result in a fresh IRModule so it can be passed to relay.build.
mod = relay.transform.InferType()(mod)
bwd_ir = relay.transform.gradient(mod['main'], mode="first_order")
bwd_mod = tvm.IRModule.from_expr(bwd_ir)
print("==" * 25, "TVM BWD IRs", "==" * 25)
print(bwd_mod['main'])
# NOTE(review): at this point the gradient IR has only been constructed and
# printed; nothing has been compiled yet (compilation is the relay.build call
# further down).
print("build [fwd + bwd] IR successful")
# Pasted (abbreviated) output of the print above. The function returns
# (forward output, (one tensor per parameter)); the inner tensor shapes match
# the four parameters, presumably their gradients.
"""
fn (%input0: Tensor[(1, 3, 224, 224), float32], %v0_0_weight: Tensor[(32, 3, 3, 3), float32], %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32], %v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> (Tensor[(1, 16, 112, 112), float32], (Tensor[(1, 3, 224, 224), float32], Tensor[(32, 3, 3, 3), float32], Tensor[(32, 1, 3, 3), float32], Tensor[(16, 32, 1, 1), float32])) {
  .... /* too long, omitted */
  (%x18, %45)
}
"""

# Compile the forward+backward module. `params` came from the forward import;
# the gradient function keeps the same parameter names (%v0_0_weight, ...),
# so the same weight bindings apply.
# This is the step that fails. In the traceback pasted below, Relay type
# inference rejects an nn.conv2d_transpose (Conv2DTransposeRel) on the
# check indexdiv(dshape_nchw[1], groups) == wshape[0]. conv2d_transpose does
# not appear in the forward IR, so it presumably comes from the gradient.
lib = relay.build(bwd_mod, target=target, params=params)
print("build [fwd + bwd] pass successful")
"""
Traceback (most recent call last):
  File "tvm_report_mbv2.py", line 90, in <module>
    lib = relay.build(bwd_mod, target=target, params=params)
  File "/home/ligeng/Workspace/tvm/python/tvm/relay/build_module.py", line 357, in build
    executor_config, runtime_mod, params = bld_mod.build(
  File "/home/ligeng/Workspace/tvm/python/tvm/relay/build_module.py", line 172, in build
    self._build(mod, target, target_host, executor, mod_name)
  File "/home/ligeng/Workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  14: TVMFuncCall
  13: _ZNSt17_Function_handlerIFvN3tvm
  12: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  11: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&, tvm::runtime::String)
  10: tvm::relay::backend::RelayBuildModule::OptimizeImpl(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&)
  9: tvm::transform::Pass::operator()(tvm::IRModule) const
  8: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  0: tvm::relay::TypeSolver::Solve() [clone .cold]
  16: TVMFuncCall
  15: _ZNSt17_Function_handlerIFvN3tvm
  14: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  13: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&, tvm::runtime::String)
  12: tvm::relay::backend::RelayBuildModule::OptimizeImpl(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&)
  11: tvm::transform::Pass::operator()(tvm::IRModule) const
  10: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  9: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  2: tvm::relay::TypeSolver::Solve()
  1: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  0: bool tvm::relay::Conv2DTransposeRel<tvm::relay::Conv2DTransposeAttrs>(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  File "/home/ligeng/Workspace/tvm/src/relay/analysis/type_solver.cc", line 622
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: [15:13:27] /home/ligeng/Workspace/tvm/src/relay/op/nn/convolution.h:1111: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0])) is false: 
"""