[Unity] TVMScript fails to capture multiple shape parameters from an enclosing function

While writing a helper function to generate basic IRModules for testing, I came across a bug. I created the following function:

def get_unary_mod(method, shape):
    """Build a test IRModule whose ``main`` applies ``method`` to one tensor.

    Args:
        method: callable applied to the input inside the dataflow block,
            e.g. ``R.log`` or ``R.tan``.
        shape: shape shared by the input, the returned tensor, and ``lv``.

    Returns:
        The module produced by applying ``tvm.script.ir.ir_module`` to
        ``expected``.
    """
    @tvm.script.ir.ir_module
    class expected:
        @R.function
        def main(
                in1: R.Tensor(shape, "float32")
        ) -> R.Tensor(shape, "float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # `shape` also appears here, inside the body of `main`, which
                # is why the parser can resolve it in this single-shape version.
                lv: R.Tensor(shape, dtype="float32") = method(in1)
                R.output(lv)
            return lv

    return expected

It takes a method (such as R.log or R.tan) and a shape shared by all tensors. I then tried to extend it to support R.reshape (with a constant target shape) by adding an output-shape parameter:

 # NOTE(review): the leading space before `def` below is a copy-paste artifact;
 # remove it to run this snippet as a top-level definition.
 def get_unary_mod(method, shape, o_shape):
    """Variant of the helper with separate input and output shapes.

    Args:
        method: callable applied to the input, e.g. a wrapper around
            ``R.reshape``.
        shape: shape of the input tensor ``in1``.
        o_shape: shape of the returned tensor and of ``lv``.

    Returns:
        The module produced by applying ``tvm.script.ir.ir_module`` to
        ``expected``.

    Note:
        Calling this fails with "Undefined variable: shape". ``shape`` is used
        only in the parameter annotation of ``main``, not in its body, so the
        TVMScript parser does not collect it as a nonlocal variable.
    """
    @tvm.script.ir.ir_module
    class expected:
        @R.function
        def main(
                in1: R.Tensor(shape, "float32")
        ) -> R.Tensor(o_shape, "float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # Only `o_shape` (not `shape`) appears inside the body of `main`.
                lv: R.Tensor(o_shape, dtype="float32") = method(in1)
                R.output(lv)
            return lv

    return expected

However, the parser then fails to resolve the shape variables (the error reports `shape` as undefined). Other parameters work without a problem, so it seems that using different shapes inside the function confuses the parser, rather than the parameters themselves.

Expected behavior

I wanted to run the case:

# Input shape and the constant shape to reshape to (2 * 3 == 6 elements).
s1, s2 = (2, 3), (6,)

def met(in1):
    """Reshape ``in1`` to the module-level target shape ``s2``."""
    return R.reshape(in1, R.shape(list(s2)))
mod = get_unary_mod(met, s1, s2)

# Print the generated module in TVMScript form.
mod.show()

That should print:

# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(in1: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((6,), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((6,), dtype="float32") = R.reshape(in1, R.shape([6]))
            R.output(lv)
        return lv

Actual behavior

Instead, the run failed with the following message:

error: Undefined variable: shape
 --> root-dir/bug.py:15:22
    |  
 15 |                  in1: R.Tensor(shape, "float32")
    |                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
backtrace message:
Traceback (most recent call last):
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/relax/parser.py", line 98, in eval_struct_info_proxy
    annotation = self.eval_expr(node)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 466, in eval_expr
    return eval_expr(self, node, var_values)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/evaluator.py", line 438, in eval_expr
    return ExprEvaluator.eval(parser, value_table, node)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/evaluator.py", line 118, in eval
    result = self._visit(node)  # pylint: disable=protected-access
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/evaluator.py", line 254, in _visit
    fields[field] = self._visit(attr)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/evaluator.py", line 215, in _visit
    return [self._visit(n) for n in node]
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/evaluator.py", line 215, in <listcomp>
    return [self._visit(n) for n in node]
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/evaluator.py", line 221, in _visit
    raise ParserError(node, f"Undefined variable: {node.id}")
tvm.script.parser.core.error.ParserError: Undefined variable: shape

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "root-dir/relax_test.py", line 30, in <module>
    mod = get_unary_mod(met, s1, s2)
  File "root-dir/relax_test.py", line 12, in get_unary_mod
    class expected:
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/ir/entry.py", line 56, in ir_module
    return decorator_wrapper(mod)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/ir/entry.py", line 50, in decorator_wrapper
    m = parse(mod, utils.inspect_class_capture(mod), check_well_formed=check_well_formed)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/entry.py", line 100, in parse
    parser.parse(extra_vars=extra_vars)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 379, in parse
    self.visit(node)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 597, in visit
    func(node)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/doc.py", line 257, in generic_visit
    self.visit(value)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 585, in visit
    self.visit(item)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 597, in visit
    func(node)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 672, in visit_ClassDef
    _dispatch_wrapper(func)(self, node)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 309, in _wrapper
    return func(self, node)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/ir/parser.py", line 63, in _visit_class_def
    global_var = self.visit_tvm_declare_function(stmt)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 654, in visit_tvm_declare_function
    return _dispatch(self, "tvm_declare_function")(self, node)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 309, in _wrapper
    return func(self, node)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/relax/parser.py", line 232, in visit_tvm_declare_function
    collect_symbolic_var_from_params(self, node)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/relax/parser.py", line 153, in collect_symbolic_var_from_params
    param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/relax/parser.py", line 101, in eval_struct_info_proxy
    self.report_error(node, str(err))
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 568, in report_error
    raise diag_err
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/parser.py", line 559, in report_error
    self.diag.error(node, msg)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/script/parser/core/diagnostics.py", line 257, in error
    self.ctx.render()
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/ir/diagnostics/__init__.py", line 119, in render
    _ffi_api.DiagnosticContextRender(self)
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "root-dir/.venv/lib/python3.10/site-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.DiagnosticError: Traceback (most recent call last):
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::DiagnosticContext)>::AssignTypedLambda<tvm::__mk_TVM11::{lambda(tvm::DiagnosticContext)#1}>(tvm::__mk_TVM11::{lambda(tvm::DiagnosticContext)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  0: tvm::DiagnosticContext::Render()
  File "tvm-root/src/ir/diagnostic.cc", line 131
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.

Environment

This was run on Ubuntu 22.04 LTS with TVM 0.17.dev0 (built after commit 1eac1785) and Python 3.10.12.

Steps to reproduce

import tvm
import tvm.relax.frontend.nnef

from tvm.script import ir as I
from tvm.script import relax as R


def get_unary_mod(method, shape, o_shape):
    """Build a test IRModule with separate input and output shapes.

    Args:
        method: callable applied to the input, e.g. a wrapper around
            ``R.reshape``.
        shape: shape of the input tensor ``in1``.
        o_shape: shape of the returned tensor and of ``lv``.

    Returns:
        The module produced by applying ``tvm.script.ir.ir_module`` to
        ``expected``.

    Note:
        Calling this fails with "Undefined variable: shape". ``shape`` is used
        only in the parameter annotation of ``main``, not in its body, so the
        TVMScript parser does not collect it as a nonlocal variable.
    """
    @tvm.script.ir.ir_module
    class expected:
        @R.function
        def main(
                in1: R.Tensor(shape, "float32")
        ) -> R.Tensor(o_shape, "float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # Only `o_shape` (not `shape`) appears inside the body of `main`.
                lv: R.Tensor(o_shape, dtype="float32") = method(in1)
                R.output(lv)
            return lv

    return expected


# Input shape and the constant shape to reshape to (2 * 3 == 6 elements).
s1, s2 = (2, 3), (6,)

def met(in1):
    """Reshape ``in1`` to the module-level target shape ``s2``."""
    return R.reshape(in1, R.shape(list(s2)))
mod = get_unary_mod(met, s1, s2)

# Print the generated module in TVMScript form.
mod.show()

Workaround

When the variables are replaced by literal values, TVM creates the module without issue. But when different variables are used for the input and output shapes, the parser no longer recognizes them as variables (the reported error is for `shape`). Declaring them as script-level global variables works, but I wanted to set them locally.

I managed to work around this by declaring the `shape` and `o_shape` variables as globals:

def get_unary_mod(method, s1, s2):
    """Workaround variant that exposes the shapes to the parser as globals.

    Args:
        method: callable applied to the input tensor.
        s1: shape of the input tensor ``in1``.
        s2: shape of the returned tensor and of ``lv``.

    Returns:
        The module produced by applying ``tvm.script.ir.ir_module`` to
        ``expected``.
    """
    # Binding the shapes as module globals (instead of function locals) lets
    # the parser resolve `shape` even though it appears only in the parameter
    # annotation of `main`. Side effect: every call overwrites the module-level
    # `shape` and `o_shape`.
    global shape, o_shape
    shape, o_shape = s1, s2
    @tvm.script.ir.ir_module
    class expected:
        @R.function
        def main(
                in1: R.Tensor(shape, "float32")
        ) -> R.Tensor(o_shape, "float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                lv: R.Tensor(o_shape, dtype="float32") = method(in1)
                R.output(lv)
            return lv

    return expected

However, this does not feel ideal: on the Python side, the `shape` and `o_shape` parameters are visible in the function's scope; it is only the parser that fails to recognize them as variables.

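This is because `shape` is used only in the parameter type annotation of `main` and not within its body. By contrast, `o_shape` also appears in the body, in the annotation of `lv`. In the original single-shape version, `shape` likewise appeared in the body (`lv: R.Tensor(shape, ...)`), which is why that version worked.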

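The TVMScript parser collects the Python nonlocal variables (nonlocal_vars) captured by the function. However, Python evaluates parameter and return annotations in the enclosing scope when the function is defined. Names that appear only there are therefore not part of the function's closure, so the parser does not see them as nonlocal_vars.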

All right. I thought it was a problem with scopes, but I didn't know that type annotations were scoped differently. Thank you for your answer.