While writing a helper function to generalize the creation of basic IRModules for testing, I came across a bug. I started with this function:
def get_unary_mod(method, shape):
    """Build a minimal Relax IRModule that applies one unary operator.

    method: callable applied to the input tensor inside the dataflow block
        (e.g. R.log, R.tan).
    shape: shape used for both the input and the output tensor annotations.

    Returns the class decorated with tvm.script.ir.ir_module.
    """
    @tvm.script.ir.ir_module
    class expected:
        @R.function
        def main(
            in1: R.Tensor(shape, "float32")
        ) -> R.Tensor(shape, "float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # `shape` is referenced here in the body as well as in the signature.
                lv: R.Tensor(shape, dtype="float32") = method(in1)
                R.output(lv)
            return lv
    return expected
This function takes a method (such as R.log or R.tan) and a shape that is shared by all tensors. I then tried to extend it to support R.reshape (reshaping to a constant shape) by adding an output-shape parameter:
def get_unary_mod(method, shape, o_shape):
    """Build a minimal Relax IRModule whose output shape may differ from its input shape.

    method: callable applied to the input tensor inside the dataflow block.
    shape: shape of the input tensor annotation.
    o_shape: shape of the output tensor annotations.

    Returns the class decorated with tvm.script.ir.ir_module.
    NOTE: this is the variant that the report says fails to parse.
    """
    @tvm.script.ir.ir_module
    class expected:
        @R.function
        def main(
            # `shape` now appears only in this parameter annotation.
            in1: R.Tensor(shape, "float32")
        ) -> R.Tensor(o_shape, "float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                lv: R.Tensor(o_shape, dtype="float32") = method(in1)
                R.output(lv)
            return lv
    return expected
However, with this change the parser fails to resolve the shape variables. Other parameters (such as `method`) are used without any problem, so the issue does not seem to be with function parameters in general; rather, using different shape variables inside the function appears to confuse the parser.
Expected behavior
I wanted to run the case:
# Input shape (2, 3) is reshaped to the flat output shape (6,).
s1, s2 = (2, 3), (6,)
# Operator under test: reshape the input to the constant shape s2.
def met(in1):
    return R.reshape(in1, R.shape(list(s2)))
# Build the module with the two-shape helper and print it as TVMScript.
mod = get_unary_mod(met, s1, s2)
mod.show()
That should print:
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main(in1: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((6,), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
lv: R.Tensor((6,), dtype="float32") = R.reshape(in1, R.shape([6]))
R.output(lv)
return lv
Actual behavior
But the run failed with the message:
error: Undefined variable: shape
--> root-dir/bug.py:15:22
|
15 | in1: R.Tensor(shape, "float32")
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
backtrace message:
Traceback (most recent call last):
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/relax/parser.py", line 98, in eval_struct_info_proxy
annotation = self.eval_expr(node)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 466, in eval_expr
return eval_expr(self, node, var_values)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/evaluator.py", line 438, in eval_expr
return ExprEvaluator.eval(parser, value_table, node)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/evaluator.py", line 118, in eval
result = self._visit(node) # pylint: disable=protected-access
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/evaluator.py", line 254, in _visit
fields[field] = self._visit(attr)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/evaluator.py", line 215, in _visit
return [self._visit(n) for n in node]
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/evaluator.py", line 215, in <listcomp>
return [self._visit(n) for n in node]
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/evaluator.py", line 221, in _visit
raise ParserError(node, f"Undefined variable: {node.id}")
tvm.script.parser.core.error.ParserError: Undefined variable: shape
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "root-dir/relax_test.py", line 30, in <module>
mod = get_unary_mod(met, s1, s2)
File "root-dir/relax_test.py", line 12, in get_unary_mod
class expected:
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/ir/entry.py", line 56, in ir_module
return decorator_wrapper(mod)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/ir/entry.py", line 50, in decorator_wrapper
m = parse(mod, utils.inspect_class_capture(mod), check_well_formed=check_well_formed)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/entry.py", line 100, in parse
parser.parse(extra_vars=extra_vars)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 379, in parse
self.visit(node)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 597, in visit
func(node)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/doc.py", line 257, in generic_visit
self.visit(value)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 585, in visit
self.visit(item)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 597, in visit
func(node)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 672, in visit_ClassDef
_dispatch_wrapper(func)(self, node)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 309, in _wrapper
return func(self, node)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/ir/parser.py", line 63, in _visit_class_def
global_var = self.visit_tvm_declare_function(stmt)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 654, in visit_tvm_declare_function
return _dispatch(self, "tvm_declare_function")(self, node)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 309, in _wrapper
return func(self, node)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/relax/parser.py", line 232, in visit_tvm_declare_function
collect_symbolic_var_from_params(self, node)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/relax/parser.py", line 153, in collect_symbolic_var_from_params
param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/relax/parser.py", line 101, in eval_struct_info_proxy
self.report_error(node, str(err))
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 568, in report_error
raise diag_err
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 559, in report_error
self.diag.error(node, msg)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/diagnostics.py", line 257, in error
self.ctx.render()
File "root-dir/.venv/lib/python3.10/site-packages/tvm/ir/diagnostics/__init__.py", line 119, in render
_ffi_api.DiagnosticContextRender(self)
File "root-dir/.venv/lib/python3.10/site-packages/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "root-dir/.venv/lib/python3.10/site-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm.error.DiagnosticError: Traceback (most recent call last):
1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::DiagnosticContext)>::AssignTypedLambda<tvm::__mk_TVM11::{lambda(tvm::DiagnosticContext)#1}>(tvm::__mk_TVM11::{lambda(tvm::DiagnosticContext)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
0: tvm::DiagnosticContext::Render()
File "tvm-root/src/ir/diagnostic.cc", line 131
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.
Environment
This was run on Ubuntu 22.04 LTS, on TVM 0.17.dev0, after commit #1eac1785, on Python 3.10.12
Steps to reproduce
import tvm
import tvm.relax.frontend.nnef
from tvm.script import ir as I
from tvm.script import relax as R
def get_unary_mod(method, shape, o_shape):
    """Reproduction helper: build a single-operator Relax IRModule.

    method: callable applied to the input tensor inside the dataflow block.
    shape: shape of the input tensor annotation.
    o_shape: shape of the output tensor annotations.

    Returns the class decorated with tvm.script.ir.ir_module.
    """
    @tvm.script.ir.ir_module
    class expected:
        @R.function
        def main(
            # The reported "Undefined variable: shape" error points at this annotation.
            in1: R.Tensor(shape, "float32")
        ) -> R.Tensor(o_shape, "float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                lv: R.Tensor(o_shape, dtype="float32") = method(in1)
                R.output(lv)
            return lv
    return expected
# Input shape (2, 3) and target output shape (6,).
s1, s2 = (2, 3), (6,)
# Operator under test: reshape the input to the constant shape s2.
def met(in1):
    return R.reshape(in1, R.shape(list(s2)))
# This call is where the reported traceback originates.
mod = get_unary_mod(met, s1, s2)
mod.show()
Workaround
If the variables are replaced with literal values, TVM has no issue creating the module. But when multiple variables are used for the shapes, none of them are recognized as variables anymore. Declaring them as script-level global variables works, but I wanted to be able to set them locally.
I managed to solve this, by declaring the shape and o_shape variables global:
def get_unary_mod(method, s1, s2):
    """Workaround variant: expose the shapes as module-level globals before parsing.

    method: callable applied to the input tensor inside the dataflow block.
    s1: shape of the input tensor; stored in the global `shape`.
    s2: shape of the output tensor; stored in the global `o_shape`.

    Returns the class decorated with tvm.script.ir.ir_module.
    """
    # Side effect: every call overwrites the module-level `shape` and `o_shape`.
    global shape, o_shape
    shape, o_shape = s1, s2
    @tvm.script.ir.ir_module
    class expected:
        @R.function
        def main(
            in1: R.Tensor(shape, "float32")
        ) -> R.Tensor(o_shape, "float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                lv: R.Tensor(o_shape, dtype="float32") = method(in1)
                R.output(lv)
            return lv
    return expected
However, this does not feel ideal: on the Python side, the `shape` and `o_shape` parameters are visible in the scope of the function, yet the parser fails to recognize them as variables.