Variable `matmul` is directly accessed by host memory

I’m trying to compile a simple neural network for Vulkan, closely modelled after the example at Quick Start — tvm 0.20.dev0 documentation, and I am running into an error. If I change the device from tvm.vulkan() to tvm.cuda() or tvm.cpu(), the model compiles and the script prints an array of 10 random numbers, as expected. The LegalizeOps and DefaultGPUSchedule transformations turned out to be necessary for CUDA: omitting either (or both) produces the same error message I get for Vulkan, which fails with every combination of these transformations. I spent a fair amount of time searching for a solution but did not find one. Please help.

Here’s my code:

import tvm
import numpy as np

from tvm import relax
from tvm.relax.frontend import nn


class MLPModel(nn.Module):
    input_dim = 784
    hidden_dim = 256
    output_dim = 10

    def __init__(self):
        super(MLPModel, self).__init__()

        self.fc1 = nn.Linear(MLPModel.input_dim, MLPModel.hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(MLPModel.hidden_dim, MLPModel.output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


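# Export the model as a Relax IRModule along with its parameter spec.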
model, param_spec = MLPModel().export_tvm(
    spec={"forward": {"x": nn.spec.Tensor((1, MLPModel.input_dim), "float32")}}
)

# tvm.cuda() or tvm.cpu() works
device = tvm.vulkan()
assert device.exist

with tvm.target.Target.from_device(device) as target:
    # The following two transformations are necessary for CUDA, but not for CPU.
    model = tvm.relax.transform.LegalizeOps()(model)
    model = tvm.tir.transform.DefaultGPUSchedule()(model)
    executable = relax.build(model, target)

vm = relax.VirtualMachine(executable, device)

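# Run the compiled function on a random input with randomly initialized parameters.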
data = np.random.rand(1, MLPModel.input_dim).astype("float32")
tvm_data = tvm.nd.array(data, device=device)
params = [np.random.rand(*param.shape).astype("float32") for _, param in param_spec]
params = [tvm.nd.array(param, device=device) for param in params]
print(vm["forward"](tvm_data, *params).numpy())

And here is the error message:

Traceback (most recent call last):
  File "/home/bence/single-node-model.py", line 36, in <module>
    executable = relax.build(model, target)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bence/git/tvm/python/tvm/relax/vm_build.py", line 353, in build
    return _vmlink(
           ^^^^^^^^
  File "/home/bence/git/tvm/python/tvm/relax/vm_build.py", line 249, in _vmlink
    lib = tvm.build(
          ^^^^^^^^^^
  File "/home/bence/git/tvm/python/tvm/driver/build_module.py", line 297, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bence/git/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 245, in __call__
    raise_last_ffi_error()
  File "/home/bence/git/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  10: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::__mk_TVM24::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#1}>(tvm::__mk_TVM24::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  9: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  8: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
  7: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
  6: tvm::transform::Pass::operator()(tvm::IRModule) const
  5: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  1: _ZN3tvm7runtime13PackedFuncObj
  0: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tir::transform::VerifyMemory()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::tir::transform::VerifyMemory()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  Did you forget to bind?
    Variable `permute_dims` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `x` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `matmul` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `matmul` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `matmul` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "/home/bence/git/tvm/src/tir/analysis/verify_memory.cc", line 205
RuntimeError: Memory verification failed with the following errors:
# from tvm.script import tir as T

@T.prim_func
def matmul(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), permute_dims: T.Buffer((T.int64(784), T.int64(256)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(256)), "float32")):
    T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
    for i1, k in T.grid(256, 784):
        matmul_1 = T.Buffer((T.int64(256),), data=matmul.data)
        if k == 0:
            matmul_1[i1] = T.float32(0.0)
        x_1 = T.Buffer((T.int64(784),), data=x.data)
        permute_dims_1 = T.Buffer((T.int64(200704),), data=permute_dims.data)
        matmul_1[i1] = matmul_1[i1] + x_1[k] * permute_dims_1[k * 256 + i1]
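
If I read the message above correctly, VerifyMemory insists that every device-side buffer access sits inside a thread environment (loops bound to blockIdx/threadIdx), which the dumped matmul lacks. For illustration only, here is a hand-bound sketch of a similar matmul (hypothetical code, not taken from my model or from TVM) that satisfies this requirement:

import tvm
from tvm.script import tir as T

# Hypothetical stand-in for the unbound matmul in the error dump above.
@T.prim_func
def matmul_sketch(x: T.Buffer((1, 784), "float32"),
                  w: T.Buffer((784, 256), "float32"),
                  out: T.Buffer((1, 256), "float32")):
    for i, k in T.grid(256, 784):
        with T.block("matmul"):
            vi, vk = T.axis.remap("SR", [i, k])
            with T.init():
                out[0, vi] = T.float32(0)
            out[0, vi] = out[0, vi] + x[0, vk] * w[vk, vi]

# Bind the spatial loop to GPU thread axes. After binding, every access to
# x/w/out happens inside a thread environment, which is what the
# "Did you forget to bind?" hint from VerifyMemory is asking for.
sch = tvm.tir.Schedule(matmul_sketch)
block = sch.get_block("matmul", func_name="main")
i, k = sch.get_loops(block)
bx, tx = sch.split(i, factors=[None, 64])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
sch.mod.show()  # the bound loops are now printed as T.thread_binding(...)

This is the kind of binding I expected DefaultGPUSchedule to insert automatically, as it appears to do on the CUDA target.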

Thanks for reporting the bug. It’s fixed in [TIR] Minor fix for default GPU schedule by Hzfengsy · Pull Request #17706 · apache/tvm · GitHub

Thank you, I verified that #17706 indeed fixes the issue.