I am trying to use TVM to compute gradients for a MobileNetV2 model imported from PyTorch. After disabling the dropout and batch-norm layers, the forward IR can now be transformed into a [fwd + bwd] IR without problems. However, when I call relay.build
to compile that IR, it raises an unexpected error (full traceback below). Can anyone advise how to work around this issue?
import torch
import torch as th
import torch.nn as nn
from torchvision import models
import torch.onnx
import numpy as np
import tvm
from tvm import relay
from tvm import relay, auto_scheduler
from tvm.relay import testing
def disable_dropout_bn(module):
    """Replace every ``nn.BatchNorm2d`` / ``nn.Dropout`` under ``module`` with ``nn.Identity``.

    The module tree is walked recursively and modified in place: each child
    is re-registered under its original name with the result of processing
    that child.

    Args:
        module: The ``nn.Module`` to process.

    Returns:
        A fresh ``nn.Identity()`` when ``module`` itself is a
        ``nn.BatchNorm2d`` or ``nn.Dropout``; otherwise the same ``module``
        object, with its matching descendants replaced.
    """
    # A matching layer is swapped out wholesale; there is nothing to recurse
    # into once it has been replaced.
    if isinstance(module, (nn.BatchNorm2d, nn.Dropout)):
        return nn.Identity()

    # Snapshot the children first so re-registering them below never mutates
    # the mapping that is being iterated.
    for name, child in list(module.named_children()):
        module.add_module(name, disable_dropout_bn(child))
    return module
# Keep the repro small: only the first two feature stages of MobileNetV2,
# with BatchNorm2d / Dropout layers swapped for Identity.
model = models.MobileNetV2().features[:2]
model = disable_dropout_bn(model)

input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)

# Trace to TorchScript, then switch the traced module to eval mode.
scripted_model = torch.jit.trace(model, input_data)
scripted_model = scripted_model.eval()

rule = "==" * 25
print(rule, "Scripted Model", rule)
print(scripted_model.code)

# Import the traced model into Relay under a single named input.
input_name = "input0"
shape_list = [(input_name, input_data.shape)]
mod, params = relay.frontend.from_pytorch(
    scripted_model,
    shape_list,
    use_parser_friendly_name=True,
)

# Run the three passes as one pipeline before printing the IR.
simplify = tvm.transform.Sequential(
    [
        relay.transform.PartialEvaluate(),
        relay.transform.DeadCodeElimination(),
        relay.transform.ToGraphNormalForm(),
    ]
)
mod = simplify(mod)

print(rule, "TVM IRs", rule)
print(mod)
"""
def @main(%input0: Tensor[(1, 3, 224, 224), float32], %v0_0_weight: Tensor[(32, 3, 3, 3), float32], %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32], %v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> Tensor[(1, 16, 112, 112), float32] {
%0 = nn.conv2d(%input0, %v0_0_weight, strides=[2, 2], padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 112, 112), float32] */;
%1 = clip(%0, a_min=0f, a_max=6f) /* ty=Tensor[(1, 32, 112, 112), float32] */;
%2 = reshape(%v1_conv_0_0_weight, newshape=[32, 1, 3, 3]) /* ty=Tensor[(32, 1, 3, 3), float32] */;
%3 = nn.conv2d(%1, %2, padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 112, 112), float32] */;
%4 = clip(%3, a_min=0f, a_max=6f) /* ty=Tensor[(1, 32, 112, 112), float32] */;
nn.conv2d(%4, %v1_conv_1_weight, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1]) /* ty=Tensor[(1, 16, 112, 112), float32] */
}
"""
# Compile the forward-only module first to confirm it builds on its own.
target = "llvm"
lib = relay.build(mod, target=target, params=params)
print("build [fwd] pass successful")

# Types are inferred on the module before `main` is handed to the gradient
# transform.
infer_type = relay.transform.InferType()
mod = infer_type(mod)
fwd_main = mod["main"]

# First-order gradient of `main`, wrapped back into its own IRModule.
bwd_ir = relay.transform.gradient(fwd_main, mode="first_order")
bwd_mod = tvm.IRModule.from_expr(bwd_ir)

rule = "==" * 25
print(rule, "TVM BWD IRs", rule)
print(bwd_mod["main"])
print("build [fwd + bwd] IR successful")
"""
fn (%input0: Tensor[(1, 3, 224, 224), float32], %v0_0_weight: Tensor[(32, 3, 3, 3), float32], %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32], %v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> (Tensor[(1, 16, 112, 112), float32], (Tensor[(1, 3, 224, 224), float32], Tensor[(32, 3, 3, 3), float32], Tensor[(32, 1, 3, 3), float32], Tensor[(16, 32, 1, 1), float32])) {
.... /* too long, omitted */
(%x18, %45)
}
"""
# Compile the combined [fwd + bwd] module. This is the call reported as
# failing: the traceback that follows shows the InferType pass run inside
# relay.build tripping a check in Conv2DTransposeRel.
# NOTE(review): presumably that conv2d_transpose comes from differentiating
# the grouped (groups=32) conv2d in the forward IR — TODO confirm.
lib = relay.build(bwd_mod, target=target, params=params)
print("build [fwd + bwd] pass successful")
"""
Traceback (most recent call last):
File "tvm_report_mbv2.py", line 90, in <module>
lib = relay.build(bwd_mod, target=target, params=params)
File "/home/ligeng/Workspace/tvm/python/tvm/relay/build_module.py", line 357, in build
executor_config, runtime_mod, params = bld_mod.build(
File "/home/ligeng/Workspace/tvm/python/tvm/relay/build_module.py", line 172, in build
self._build(mod, target, target_host, executor, mod_name)
File "/home/ligeng/Workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
14: TVMFuncCall
13: _ZNSt17_Function_handlerIFvN3tvm
12: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
11: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&, tvm::runtime::String)
10: tvm::relay::backend::RelayBuildModule::OptimizeImpl(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&)
9: tvm::transform::Pass::operator()(tvm::IRModule) const
8: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
7: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
5: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
0: tvm::relay::TypeSolver::Solve() [clone .cold]
16: TVMFuncCall
15: _ZNSt17_Function_handlerIFvN3tvm
14: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
13: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&, tvm::runtime::String)
12: tvm::relay::backend::RelayBuildModule::OptimizeImpl(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&)
11: tvm::transform::Pass::operator()(tvm::IRModule) const
10: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
9: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
8: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
7: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
4: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
2: tvm::relay::TypeSolver::Solve()
1: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
0: bool tvm::relay::Conv2DTransposeRel<tvm::relay::Conv2DTransposeAttrs>(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
File "/home/ligeng/Workspace/tvm/src/relay/analysis/type_solver.cc", line 622
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (false) is false: [15:13:27] /home/ligeng/Workspace/tvm/src/relay/op/nn/convolution.h:1111:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0])) is false:
"""