"Maximum recursion depth exceeded" error while trying to use a tensor intrinsic

Hello, I am trying to write a conv2d computation and tensorize it. Here is my code:

import numpy as np

import tvm
from tvm import te
from tvm import autotvm
from tvm import topi
from zte.enviroment import get_env
from zte.intrin import gemm

# Symbolic int16 input tensors. The compute below indexes data as
# (n, c, h, w) and kernel as (o, c, k_h, k_w).
data = tvm.te.placeholder((1, 64, 112, 112), dtype="int16", name="data")
kernel = tvm.te.placeholder((64, 64, 3, 3), dtype="int16", name="kernel")

# Output spatial size: 112 - 3 + 1 = 110 for a 3x3 window at stride 1 without
# padding. out_w / out_h are not referenced again in this snippet; oshape
# repeats the value literally.
out_w = 110
out_h = 110
out_dtype = "int32"
oshape = (1, 64, 110, 110)

# Define the conv2d operator. No padding is applied in this snippet: data is
# indexed directly. The reduction runs over the input channel c and the 3x3
# window (k_h, k_w); both operands are cast to out_dtype before multiplying.
c = te.reduce_axis((0, 64), name="c")
k_h = te.reduce_axis((0, 3), name="k_h")
k_w = te.reduce_axis((0, 3), name="k_w")
hstride, wstride = 1, 1
res = te.compute(
    oshape,
    lambda n, o, h, w: te.sum(
        data[n, c, h * hstride + k_h, w * wstride + k_w].astype(out_dtype)
        * kernel[o, c, k_h, k_w].astype(out_dtype),
        axis=[c, k_h, k_w],
    ),
    name="res",
    tag="conv2d_dense",
)

s = tvm.te.create_schedule(res.op)

env = get_env()

# Spatial axes (n, o, h, w) and reduce axes of the res stage.
# h_1 / w_1 are unpacked but never used; the fuse below passes the
# module-level k_h / k_w instead (presumably the same IterVars - TODO confirm).
n_o, o_o, h_o, w_o = s[res].op.axis
c_in, h_1, w_1 = s[res].op.reduce_axis

# Fuse the two output spatial axes into one loop and the two kernel-window
# reduce axes into another, then request the loop order
# n -> kernel window -> output pixels -> output channel -> input channel.
h_w_data = s[res].fuse(h_o, w_o)
h_w_kernel = s[res].fuse(k_h, k_w)
s[res].reorder(n_o, h_w_kernel, h_w_data, o_o, c_in)

# Tile output channel (o_o) against input channel (c_in) with factor 16 each,
# fuse the two outer tile loops, and split the fused pixel loop by 256.
# NOTE(review): x_y_fuse fuses a loop derived from a spatial axis with one
# derived from a reduce axis - confirm the schedule accepts this.
# x_y_fuse, data_out and data_in are not used again in this snippet.
xo, yo, xi, yi = s[res].tile(o_o, c_in, x_factor=16, y_factor=16)
x_y_fuse = s[res].fuse(xo, yo)
data_out, data_in = s[res].split(h_w_data, 256)

# Replace everything nested under h_w_kernel with the GEMM tensor intrinsic.
# NOTE(review): after the reorder above h_w_kernel is the second-outermost
# loop, so the tensorized region includes the pixel loops and both tile
# levels; confirm this matches the 2-D compute declared by gemm().
s[res].tensorize(h_w_kernel, gemm(env, 0, 16, 16, 16))

# Lower the schedule to TIR for inspection.
code = tvm.lower(s, [data, kernel, res], simple_mode=True)

And this is the intrinsic definition (intrin.py):

from __future__ import absolute_import as _abs

import tvm
from tvm import te


def gemm(env, GS, GM, GK, GN):
    """Declare the matrix-matrix multiply (GEMM) tensor intrinsic.

    The intrinsic pairs a 2-D multiply-accumulate compute declaration with an
    implementation that emits a single call to the external device function
    ``zte_gemm_fp16``.

    Parameters
    ----------
    env : Environment
        The Environment. This function reads ``WF_ELEM_BITS``, ``WF_WIDTH``,
        ``MID_ELEM_BITS``, ``MID_WIDTH``, ``BLOCK_NROW``, ``BLOCK_NCOL``,
        ``wf_scope``, ``mid_scope`` and ``dev`` from it.

    GS : int
        Forwarded unchanged as the first scalar argument of ``zte_gemm_fp16``.

    GM : int
        Multiplier of ``env.BLOCK_NROW`` for the weight/output row count;
        also forwarded to ``zte_gemm_fp16``.

    GK : int
        Multiplier of ``env.BLOCK_NCOL`` for the reduction extent; also
        forwarded to ``zte_gemm_fp16``.

    GN : int
        Multiplier used for the input/output column extents; also forwarded
        to ``zte_gemm_fp16``.

    Returns
    -------
    intrin : TensorIntrin
        The result of ``te.decl_tensor_intrin``, usable with
        ``stage.tensorize``.

    Raises
    ------
    AssertionError
        If the computed lane count of any buffer is smaller than
        ``env.BLOCK_NROW * env.BLOCK_NCOL``.
    """
    block_elems = env.BLOCK_NROW * env.BLOCK_NCOL

    wgt_lanes = int((env.WF_ELEM_BITS / 2) // env.WF_WIDTH)
    assert wgt_lanes >= block_elems, "weight lanes smaller than one block"
    wgt_shape = (env.BLOCK_NROW * GM, env.BLOCK_NCOL * GK)

    in_lanes = int((env.WF_ELEM_BITS / 2) // env.WF_WIDTH)
    assert in_lanes >= block_elems, "input lanes smaller than one block"
    in_shape = (env.BLOCK_NCOL * GK, env.BLOCK_NROW * GN)

    out_lanes = int(env.MID_ELEM_BITS // env.MID_WIDTH)
    assert out_lanes >= block_elems, "output lanes smaller than one block"
    out_shape = (env.BLOCK_NROW * GM, env.BLOCK_NCOL * GN)

    # Both operand placeholders share the same element type and name.
    operand_dtype = "int%d" % env.WF_WIDTH
    wgt = te.placeholder(wgt_shape, dtype=operand_dtype, name=env.wf_scope)
    inp = te.placeholder(in_shape, dtype=operand_dtype, name=env.wf_scope)

    k = te.reduce_axis((0, wgt_shape[1]), name="k")
    out_dtype = "int%d" % env.MID_WIDTH
    # NOTE(review): inp is declared as (NCOL*GK, NROW*GN) but is indexed as
    # [i, k] with i < NROW*GM and k < NCOL*GK, and wgt is indexed as [j, k]
    # with j < NCOL*GN although its first extent is NROW*GM. The extents only
    # line up when these products coincide - confirm the intended layout.
    out = te.compute(
        out_shape,
        lambda i, j: te.sum(
            inp[i, k].astype(out_dtype) * wgt[j, k].astype(out_dtype), axis=[k]
        ),
        name="out",
    )

    # Buffer declarations that bind each tensor to its storage scope.
    wgt_layout = tvm.tir.decl_buffer(
        wgt.shape,
        wgt.dtype,
        env.wf_scope,
        scope=env.wf_scope,
        offset_factor=wgt_lanes,
        data_alignment=wgt_lanes,
    )
    inp_layout = tvm.tir.decl_buffer(
        inp.shape,
        inp.dtype,
        env.wf_scope,
        scope=env.wf_scope,
        offset_factor=in_lanes,
        data_alignment=in_lanes,
    )
    out_layout = tvm.tir.decl_buffer(
        out.shape,
        out.dtype,
        env.mid_scope,
        scope=env.mid_scope,
        offset_factor=out_lanes,
        data_alignment=out_lanes,
    )

    def intrin_func(ins, outs):
        """Build the statement that implements the GEMM intrinsic.

        ``ins`` holds the two bound input buffers and ``outs`` the bound
        output buffer; the returned statement is a single extern call wrapped
        in a ``zte_uop_scope`` attribute.
        """
        dinp, dwgt = ins
        dout = outs[0]

        irb = tvm.tir.ir_builder.create()
        # NOTE(review): env.dev is read while decl_tensor_intrin is still
        # running. If building the device context itself calls gemm(), this
        # re-enters the function - confirm the environment does not do that.
        dev = env.dev
        irb.scope_attr(dev.zte_axis, "zte_uop_scope", tvm.tir.StringImm("zte_gemm"))
        # "zte_gemm_fp16" is not an op registered in the tir namespace, so
        # tvm.tir.call_intrin rejects the bare string with a ValueError.
        # Emit it as a call to an external function instead.
        irb.emit(
            tvm.tir.call_extern(
                "int16",
                "zte_gemm_fp16",
                dout.access_ptr("rw", "int16"),
                dinp.access_ptr("r", "int16"),
                dwgt.access_ptr("r", "int16"),
                GS,
                GK,
                GN,
                GM,
            )
        )

        return irb.get()

    return te.decl_tensor_intrin(
        out.op, intrin_func, name="GEMM", binds={inp: inp_layout, wgt: wgt_layout, out: out_layout}
    )

This is the error I get:

Traceback (most recent call last):
  File "/home/tonywu/Documents/tvm/zte/test_con2d.py", line 53, in <module>
    s[res].tensorize(h_w_kernel, gemm(env, 0, 16, 16, 16))
  File "/home/tonywu/Documents/tvm/zte/intrin.py", line 113, in gemm
    return te.decl_tensor_intrin(
  File "/home/tonywu/Documents/tvm/python/tvm/te/tensor_intrin.py", line 139, in decl_tensor_intrin
    body = fcompute(binds_list[: len(inputs)], binds_list[len(inputs) :])
  File "/home/tonywu/Documents/tvm/zte/intrin.py", line 94, in intrin_func
    dev = env.dev
  File "/home/tonywu/Documents/tvm/zte/enviroment.py", line 135, in dev
    self._dev_ctx = DevContext(self)
  File "/home/tonywu/Documents/tvm/zte/enviroment.py", line 40, in __init__
    self.gemm = intrin.gemm(env, 0, 16, 16, 16)
  File "/home/tonywu/Documents/tvm/zte/intrin.py", line 113, in gemm
    return te.decl_tensor_intrin(
  File "/home/tonywu/Documents/tvm/python/tvm/te/tensor_intrin.py", line 139, in decl_tensor_intrin
    body = fcompute(binds_list[: len(inputs)], binds_list[len(inputs) :])
  File "/home/tonywu/Documents/tvm/zte/intrin.py", line 97, in intrin_func
    tvm.tir.call_intrin(
  File "/home/tonywu/Documents/tvm/python/tvm/tir/op.py", line 99, in call_intrin
    return Call(dtype, func_name, convert(args), span)
  File "/home/tonywu/Documents/tvm/python/tvm/tir/expr.py", line 1157, in __init__
    raise ValueError(
ValueError: Cannot handle str op argument zte_gemm_fp16. This function only handles str argument with the tir namespace. If you are certain about the intrinsic name, pass in Op.get(name) instead
Exception ignored in: <function ObjectBase.__del__ at 0x7f01c570fc10>

It is followed by more errors, ending in a RecursionError:

  File "/home/tonywu/Documents/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 119, in _make_tvm_args
    values[i].v_handle = arg.handle
  File "/home/tonywu/Documents/tvm/python/tvm/runtime/object.py", line 63, in __getattr__
    return _ffi_node_api.NodeGetAttr(self, name)
  File "/home/tonywu/Documents/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 223, in __call__
    values, tcodes, num_args = _make_tvm_args(args, temp_args)
  File "/home/tonywu/Documents/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 119, in _make_tvm_args
    values[i].v_handle = arg.handle
  File "/home/tonywu/Documents/tvm/python/tvm/runtime/object.py", line 63, in __getattr__
    return _ffi_node_api.NodeGetAttr(self, name)
  File "/home/tonywu/Documents/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 223, in __call__
    values, tcodes, num_args = _make_tvm_args(args, temp_args)
RecursionError: maximum recursion depth exceeded

Thanks a lot in advance for any advice!