I'm a beginner with TVM, and I'm trying to build the AlexNet model by hand in Relay IR. Compilation fails with the error below, and I can't identify the root cause. I'd appreciate any help from the community, thank you.
import torch
import tvm
from tvm import relay
from tvm.relay import testing
from tvm.relay.testing import layers
from tvm.contrib import graph_executor
import numpy as np
def alexnet(input: relay.expr.Expr) -> relay.expr.Expr:
    # Layer 0: conv, in=1, out=96, kernel 11x11, stride 4, padding 1
    layer0 = relay.nn.conv2d(input, relay.var("layer.0.weight", shape=(96, 1, 11, 11)), kernel_size=(11, 11), strides=(4, 4), padding=(1, 1))
    # expected output: [1, 96, 54, 54]
    layer0_out = relay.nn.bias_add(layer0, relay.var("layer.0.bias"))
    # Layer 1: ReLU activation
    layer1 = relay.nn.relu(layer0_out)
    # Layer 2: max pooling, kernel 3x3, stride 2, padding 0, ceil_mode=False
    layer2 = relay.nn.max_pool2d(layer1, pool_size=(3, 3), strides=(2, 2), padding=(0, 0), ceil_mode=False)
    # Layer 3: conv, in=96, out=256, kernel 5x5, stride 1, padding 2
    layer3 = relay.nn.conv2d(layer2, relay.var("layer.3.weight", shape=(256, 96, 5, 5)), kernel_size=(5, 5), strides=(1, 1), padding=(2, 2))
    layer3_out = relay.nn.bias_add(layer3, relay.var("layer.3.bias"))
    # Layer 4: ReLU activation
    layer4 = relay.nn.relu(layer3_out)
    # Layer 5: max pooling, kernel 3x3, stride 2, padding 0, ceil_mode=False
    layer5 = relay.nn.max_pool2d(layer4, pool_size=(3, 3), strides=(2, 2), padding=(0, 0), ceil_mode=False)
    # Layer 6: conv, in=256, out=384, kernel 3x3, stride 1, padding 1
    layer6 = relay.nn.conv2d(layer5, relay.var("layer.6.weight", shape=(384, 256, 3, 3)), kernel_size=(3, 3), strides=(1, 1), padding=(1, 1))
    layer6_out = relay.nn.bias_add(layer6, relay.var("layer.6.bias"))
    # Layer 7: ReLU activation
    layer7 = relay.nn.relu(layer6_out)
    # Layer 8: conv, in=384, out=384, kernel 3x3, stride 1, padding 1
    layer8 = relay.nn.conv2d(layer7, relay.var("layer.8.weight", shape=(384, 384, 3, 3)), kernel_size=(3, 3), strides=(1, 1), padding=(1, 1))
    layer8_out = relay.nn.bias_add(layer8, relay.var("layer.8.bias"))
    # Layer 9: ReLU activation
    layer9 = relay.nn.relu(layer8_out)
    # Layer 10: conv, in=384, out=256, kernel 3x3, stride 1, padding 1
    layer10 = relay.nn.conv2d(layer9, relay.var("layer.10.weight", shape=(256, 384, 3, 3)), kernel_size=(3, 3), strides=(1, 1), padding=(1, 1))
    layer10_out = relay.nn.bias_add(layer10, relay.var("layer.10.bias"))
    # Layer 11: ReLU activation
    layer11 = relay.nn.relu(layer10_out)
    # Layer 12: max pooling, kernel 3x3, stride 2, padding 0, ceil_mode=False
    layer12 = relay.nn.max_pool2d(layer11, pool_size=(3, 3), strides=(2, 2), padding=(0, 0), ceil_mode=False)
    # Layer 13: flatten; expected torch.Size([1, 6400])
    layer13 = relay.nn.batch_flatten(layer12)
    # Layer 14: linear, 6400 -> 4096
    layer14_w = relay.var("layer.14.weight", shape=(4096, 6400))
    layer14 = relay.nn.matmul(layer13, relay.transpose(layer14_w, axes=(1, 0)))  # [1, 4096]
    # layer14 = relay.nn.dense(layer13, relay.var("layer.14.weight"))
    # add bias
    layer14_out = relay.nn.bias_add(layer14, relay.var("layer.14.bias"), axis=-1)  # shape=(1, 4096)
    # Layer 15: ReLU activation
    layer15 = relay.nn.relu(layer14_out)
    # Layer 16: dropout, not needed at inference time
    layer16 = layer15
    # Layer 17: linear, 4096 -> 4096
    layer17_w = relay.var("layer.17.weight", shape=(4096, 4096))
    layer17 = relay.nn.matmul(layer16, relay.transpose(layer17_w, axes=(1, 0)))
    # add bias
    layer17_out = relay.nn.bias_add(layer17, relay.var("layer.17.bias"), axis=-1)
    # Layer 18: ReLU activation
    layer18 = relay.nn.relu(layer17_out)
    # Layer 19: dropout, not needed at inference time
    layer19 = layer18
    # Layer 20: linear, 4096 -> 10
    layer20_w = relay.var("layer.20.weight", shape=(10, 4096))
    layer20 = relay.nn.matmul(layer19, relay.transpose(layer20_w, axes=(1, 0)))
    layer20_out = relay.nn.bias_add(layer20, relay.var("layer.20.bias"), axis=-1)
    return layer20_out
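For reference, here is a small shape-arithmetic sketch (plain Python, independent of TVM; the helper name conv_out is mine) that applies the usual floor((in + 2*pad - kernel) / stride) + 1 rule to the conv and pool layers above. The [1, 96, 54, 54] and 6400 comments match a 224x224 input; with the 1x28x28 input declared in the main script below, the spatial size already reaches zero at the second max pool, which may be related to the divide-by-zero:

def conv_out(size, kernel, stride, pad):
    # standard output-size rule for conv/pool with floor rounding
    return (size + 2 * pad - kernel) // stride + 1

for size in (224, 28):
    s = conv_out(size, 11, 4, 1)  # layer 0 conv  -> 54 for 224, 5 for 28
    s = conv_out(s, 3, 2, 0)      # layer 2 pool  -> 26 / 2
    s = conv_out(s, 5, 1, 2)      # layer 3 conv  -> 26 / 2
    s = conv_out(s, 3, 2, 0)      # layer 5 pool  -> 12 / 0  (!)
    s = conv_out(s, 3, 1, 1)      # layers 6, 8, 10 keep the size
    s = conv_out(s, 3, 1, 1)
    s = conv_out(s, 3, 1, 1)
    s = conv_out(s, 3, 2, 0)      # layer 12 pool -> 5 / negative
    print(size, "->", s, "flattened:", 256 * s * s)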
if __name__ == "__main__":
    # load the pretrained model
    torch.load("/home/wupengxin/alexnext.pth")
    # net.load_state_dict(torch.load("alexnet.pt"))
    # define the input
    input = relay.var("input", shape=(1, 1, 28, 28), dtype="float32")
    # load the model parameters
    params = torch.load("/home/wupengxin/alexnext_params.pth")
    # convert to a TVM parameter dictionary
    relay_params = {"layer." + k: tvm.nd.array(v.cpu().numpy().astype("float32")) for k, v in params.items()}
    # build the network; returns an expr.Expr
    net = alexnet(input)
    func = relay.Function(relay.analysis.free_vars(net), net)
    func = relay.build_module.bind_params_by_name(func, relay_params)
    # print the Relay IR
    # print(func)
    target = tvm.target.Target("cuda")
    # device to run on
    dev = tvm.cuda(1)
    # compile
    with tvm.transform.PassContext(opt_level=1):
        lib = relay.build(func, target=target, params=relay_params)
    print(lib)
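For what it's worth, a minimal way to inspect the shapes Relay infers before calling relay.build (a sketch using the standard InferType pass on the func built above) is:

mod = tvm.IRModule.from_expr(func)
# run type inference so every intermediate tensor gets a concrete shape
mod = relay.transform.InferType()(mod)
print(mod)  # the printed IR annotates each call with its inferred type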
The error:
81: TVMFuncCall
80: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::RelayBuildModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
79: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
78: void tvm::relay::backend::ExecutorCodegen::CallFunc<tvm::IRModule, tvm::relay::Function, tvm::runtime::String>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
77: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::GraphExecutorCodegenModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
76: tvm::relay::backend::GraphExecutorCodegen::Codegen(tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
75: tvm::transform::Pass::operator()(tvm::IRModule) const
74: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
73: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
72: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
71: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
70: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::tec::LowerTE(tvm::runtime::String, tvm::CompilationConfig, std::function<void (tvm::BaseFunc)>)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::tec::LowerTE(tvm::runtime::String, tvm::CompilationConfig, std::function<void (tvm::BaseFunc)>)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
69: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String const&, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)
68: tvm::transform::Pass::operator()(tvm::IRModule) const
67: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
66: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
65: _ZN3tvm7runtime13PackedFuncObj
64: tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::tec::LowerTensorExpr(tvm::relay::tec::TECompiler, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::tec::LowerTensorExpr(tvm::relay::tec::TECompiler, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
63: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
62: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
61: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
60: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
59: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
58: _ZN3tvm5relay9transform22DeviceAwareExprMutator21DeviceAwareVisit
57: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
56: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
55: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
54: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
53: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
52: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
51: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
50: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
49: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
48: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
47: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
46: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
45: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
44: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
43: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
42: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
41: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
40: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
39: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
38: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
37: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
36: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
35: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
34: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
33: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
32: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
31: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
30: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
29: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
28: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
27: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
26: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
25: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
24: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
23: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
22: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
21: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
20: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
19: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
18: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
17: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
16: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
15: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
14: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
13: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
12: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
11: tvm::relay::tec::TECompilerImpl::Lower(tvm::relay::tec::CCacheKey const&)
10: tvm::relay::tec::TECompilerImpl::LowerInternal(tvm::relay::tec::CCacheKey const&, tvm::GlobalVarSupply)
9: tvm::LowerSchedule(tvm::te::Schedule, tvm::runtime::Array<tvm::te::Tensor, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, tvm::GlobalVarSupply, bool)
8: tvm::LowerSchedule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, tvm::GlobalVarSupply, bool)
7: tvm::ScheduleToModule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, tvm::GlobalVarSupply)
6: tvm::te::InferBound(tvm::te::Schedule const&)
5: tvm::te::InferRootBound(tvm::te::Stage const&, tvm::te::GraphContext const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > >*)
4: tvm::te::PassUpDomain(tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::tir::IterVar, tvm::arith::IntSet, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::arith::IntSet> > >*)
3: tvm::te::PassUpDomain(tvm::te::FuseNode const*, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, tvm::arith::IntSet const&, tvm::arith::IntSet*, tvm::arith::IntSet*)
2: tvm::indexdiv(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
1: tvm::floordiv(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
0: tvm::runtime::Optional<tvm::PrimExpr> tvm::arith::TryConstFold<tvm::tir::FloorDiv>(tvm::PrimExpr, tvm::PrimExpr)
File "tvm/src/arith/const_fold.h", line 285
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information
---------------------------------------------------------------
Check failed: pb->value != 0 (0 vs. 0) : Divide by zero