For the BatchNorm1d operator, if the attribute track_running_stats=False, TVM crashes and throws: Exception: warning unhandled case: <class 'NoneType'>
Question:
What introduces the NoneType value during the conversion of the BatchNorm1d operator?
Steps to reproduce
# Minimal reproduction: trace a BatchNorm1d built with track_running_stats=False
# and convert it to Relay with the PyTorch frontend.
import torch
from tvm import relay
# NOTE(review): with track_running_stats=False PyTorch keeps no running
# mean/variance buffers (they are None). This is presumably where the NoneType
# comes from. TODO confirm against the frontend's batch_norm converter.
m = torch.nn.BatchNorm1d(3,track_running_stats=False) # construction succeeds; the crash happens in from_pytorch below
input_data=[torch.randn([4, 3], dtype=torch.float32)]
trace = torch.jit.trace(m, input_data)
input_shapes = [('input0', torch.Size([4, 3]))]
# Raises: Exception: warning unhandled case: <class 'NoneType'>
# The span filler reaches a None entry among a converted call's arguments.
mod, params = relay.frontend.from_pytorch(trace, input_shapes)
Traceback
Traceback (most recent call last):
File "test.py", line 10, in <module>
mod, params = relay.frontend.from_pytorch(trace, input_shapes)
File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/pytorch.py", line 4969, in from_pytorch
outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/pytorch.py", line 4242, in convert_operators
relay_out = set_span(relay_out, self.source_map[op_node])
File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/common.py", line 1216, in set_span
return _SpanFiller(span).fill(sym)
File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/common.py", line 1155, in fill
return self.visit(sym)
File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/common.py", line 1075, in visit
return super().visit(expr)
File "/workplace/software/tvm/tvm/python/tvm/relay/expr_functor.py", line 60, in visit
res = self.visit_tuple_getitem(expr)
File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/common.py", line 1119, in visit_tuple_getitem
op, self.visit(op.tuple_value), op.index, None, self._span
File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/common.py", line 1075, in visit
return super().visit(expr)
File "/workplace/software/tvm/tvm/python/tvm/relay/expr_functor.py", line 48, in visit
res = self.visit_call(expr)
File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/common.py", line 1091, in visit_call
new_args = [self.visit(arg) for arg in call.args]
File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/common.py", line 1091, in <listcomp>
new_args = [self.visit(arg) for arg in call.args]
File "/workplace/software/tvm/tvm/python/tvm/relay/frontend/common.py", line 1075, in visit
return super().visit(expr)
File "/workplace/software/tvm/tvm/python/tvm/relay/expr_functor.py", line 76, in visit
raise Exception(f"warning unhandled case: {type(expr)}")
Exception: warning unhandled case: <class 'NoneType'>