Memory verification failed with Relax

I’m trying to create a simple MLP with the Relax NN library:

import tvm.relax.frontend.nn as nn
from tvm.relax.frontend.nn import spec
import tvm
from tvm import relax

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

model = MLP(input_size=1, hidden_size=2, output_size=1)

mod_spec = {
    "forward": {
        "x": spec.Tensor([1, 1], "float32")
    }
}

mod, params = model.export_tvm(spec=mod_spec)
target = tvm.target.Target("cuda", host="cuda")
ex = relax.build(mod, target=target)

However, I get the following error when running this:

root@tegra-ubuntu:~/app# python3 play.py 
Traceback (most recent call last):
  File "/root/app/play.py", line 27, in <module>
    ex = relax.build(mod, target=target)
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/vm_build.py", line 341, in build
    return _vmlink(
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/vm_build.py", line 247, in _vmlink
    lib = tvm.build(
  File "/usr/local/lib/python3.10/dist-packages/tvm/driver/build_module.py", line 294, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/usr/local/lib/python3.10/dist-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x278) [0xffff8870c598]
  [bt] (7) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x428) [0xffff8870d0b8]
  [bt] (6) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x278) [0xffff8870c598]
  [bt] (5) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x1c8) [0xffff8870aeac]
  [bt] (4) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1f13674) [0xffff890e3674]
  [bt] (3) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1f13294) [0xffff890e3294]
  [bt] (2) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1f1067c) [0xffff890e067c]
  [bt] (1) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::detail::LogFatal::Entry::Finalize()+0x68) [0xffff883f36a8]
  [bt] (0) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::Backtrace[abi:cxx11]()+0x30) [0xffff8a24d050]
  Did you forget to bind?
    Variable `A` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `T_transpose` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "/opt/mlc-llm/3rdparty/tvm/src/tir/analysis/verify_memory.cc", line 205
RuntimeError: Memory verification failed with the following errors:
# from tvm.script import tir as T

@T.prim_func
def transpose1(A: T.Buffer((T.int64(1), T.int64(2)), "float32"), T_transpose: T.Buffer((T.int64(2), T.int64(1)), "float32")):
    T.func_attr({"target": T.target({"arch": "sm_87", "host": {"arch": "sm_87", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
    for ax0 in range(2):
        T_transpose_1 = T.Buffer((T.int64(2),), data=T_transpose.data)
        A_1 = T.Buffer((T.int64(2),), data=A.data)
        T_transpose_1[ax0] = A_1[ax0]

If I change the target to LLVM or WebGPU, it works without any error, but I want to run in a CUDA environment.

Any help fixing this is highly appreciated :slight_smile:

Try adding this before relax.build:

    with tvm.target.Target("cuda"):
        mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
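
DefaultGPUSchedule binds the generated TIR loops to GPU block/thread axes, which is what the "Did you forget to bind?" check is complaining about.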

@happyme531 thanks for the response.

I'm still getting the same error:

root@tegra-ubuntu:~/app# python3 play.py 
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 1), dtype="float32"), layer1_weight: R.Tensor((2, 1), dtype="float32"), layer1_bias: R.Tensor((2,), dtype="float32"), layer2_weight: R.Tensor((1, 2), dtype="float32"), layer2_bias: R.Tensor((1,), dtype="float32")) -> R.Tensor((1, 1), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((1, 2), dtype="float32") = R.permute_dims(layer1_weight, axes=None)
            matmul: R.Tensor((1, 2), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 2), dtype="float32") = R.add(matmul, layer1_bias)
            permute_dims1: R.Tensor((2, 1), dtype="float32") = R.permute_dims(layer2_weight, axes=None)
            matmul1: R.Tensor((1, 1), dtype="float32") = R.matmul(add, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 1), dtype="float32") = R.add(matmul1, layer2_bias)
            gv: R.Tensor((1, 1), dtype="float32") = add1
            R.output(gv)
        return gv
------
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 1), dtype="float32"), layer1_weight: R.Tensor((2, 1), dtype="float32"), layer1_bias: R.Tensor((2,), dtype="float32"), layer2_weight: R.Tensor((1, 2), dtype="float32"), layer2_bias: R.Tensor((1,), dtype="float32")) -> R.Tensor((1, 1), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((1, 2), dtype="float32") = R.permute_dims(layer1_weight, axes=None)
            matmul: R.Tensor((1, 2), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 2), dtype="float32") = R.add(matmul, layer1_bias)
            permute_dims1: R.Tensor((2, 1), dtype="float32") = R.permute_dims(layer2_weight, axes=None)
            matmul1: R.Tensor((1, 1), dtype="float32") = R.matmul(add, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 1), dtype="float32") = R.add(matmul1, layer2_bias)
            gv: R.Tensor((1, 1), dtype="float32") = add1
            R.output(gv)
        return gv
Traceback (most recent call last):
  File "/root/app/play.py", line 47, in <module>
    relax.build(mod, "cuda")
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/vm_build.py", line 341, in build
    return _vmlink(
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/vm_build.py", line 247, in _vmlink
    lib = tvm.build(
  File "/usr/local/lib/python3.10/dist-packages/tvm/driver/build_module.py", line 294, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/usr/local/lib/python3.10/dist-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x278) [0xffff5fd0c598]
  [bt] (7) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x428) [0xffff5fd0d0b8]
  [bt] (6) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x278) [0xffff5fd0c598]
  [bt] (5) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x1c8) [0xffff5fd0aeac]
  [bt] (4) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1f13674) [0xffff606e3674]
  [bt] (3) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1f13294) [0xffff606e3294]
  [bt] (2) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1f1067c) [0xffff606e067c]
  [bt] (1) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::detail::LogFatal::Entry::Finalize()+0x68) [0xffff5f9f36a8]
  [bt] (0) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::Backtrace[abi:cxx11]()+0x30) [0xffff6184d050]
  Did you forget to bind?
    Variable `A` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `T_transpose` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "/opt/mlc-llm/3rdparty/tvm/src/tir/analysis/verify_memory.cc", line 205
RuntimeError: Memory verification failed with the following errors:
# from tvm.script import tir as T

@T.prim_func
def transpose1(A: T.Buffer((T.int64(1), T.int64(2)), "float32"), T_transpose: T.Buffer((T.int64(2), T.int64(1)), "float32")):
    T.func_attr({"target": T.target({"arch": "sm_87", "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
    for ax0 in range(2):
        T_transpose_1 = T.Buffer((T.int64(2),), data=T_transpose.data)
        A_1 = T.Buffer((T.int64(2),), data=A.data)
        T_transpose_1[ax0] = A_1[ax0]

Code:

with tvm.target.Target("cuda"):
    print(mod)
    print('------')
    mod = tvm.tir.transform.DefaultGPUSchedule()(mod) 
    print(mod)
    relax.build(mod, "cuda")

Call relax.transform.LegalizeOps() before DefaultGPUSchedule, and then build.
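
Something like this (a minimal sketch of the pass order, assuming model and mod_spec come from the script above; details may vary slightly across TVM versions):

mod, params = model.export_tvm(spec=mod_spec)

target = tvm.target.Target("cuda")
with target:
    # Lower the high-level Relax ops (matmul, add, permute_dims, ...) to TIR PrimFuncs.
    mod = relax.transform.LegalizeOps()(mod)
    # Bind the resulting TIR loops to GPU block/thread axes.
    mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
ex = relax.build(mod, target=target)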

Any more suggestions? I see the same issue with Llama-2 7B on Vulkan and ROCm after applying the two suggestions above. Thanks!