Error "direct host side access to device memory is detected in ... , did you forget to bind?" when compile a onnx model with target cuda

I was training a ResNet-50 model in PyTorch and stored the checkpoint in ONNX format.

ok thanks. I’ll take a look.

Can you share the PyTorch code for your ResNet-50, or did you use the one from torchvision?

Sorry for not mentioning that there is an SE block in the ResNet-50. Here is the PyTorch code for it: https://github.com/yqwang/tvm_llvm_json/blob/master/resnet.py

Hmm, I've never heard of SELayer. I guess this is what is causing the error, because it is not a standard layer.
TVM and NNVM are tested mostly on standard ImageNet models. If you try something new, weird errors might arise.

Is SELayer the same as this one?

Yes, it is:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: global average pool to a (b, c) channel descriptor
        y = F.avg_pool2d(x, kernel_size=x.size()[2:]).view(b, c)
        # Excite: per-channel weights in (0, 1), reshaped for broadcasting
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

Ok, I can reproduce your error. I think this is an interesting case.

Let me dig into it a bit.

I figured out the cause of your error. If you change this line to

if tag.is_injective(OP.tag):

it should work.
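For context, here is a minimal sketch of the traverse helper that TOPI CUDA schedules use, with the check already switched to tag.is_injective. The exact file and surrounding code depend on your TVM checkout, so treat this as illustrative rather than the real source:

from topi import tag

def traverse(s, OP, scheduled_ops):
    # The broader tag.is_injective(OP.tag) check (instead of a narrower one
    # such as tag.is_broadcast) also catches ops like the SELayer's
    # channel-wise multiply, so they get inlined into their consumers rather
    # than being left unscheduled on the host side.
    if tag.is_injective(OP.tag):
        if OP not in s.outputs:
            s[OP].compute_inline()
        for tensor in OP.input_tensors:
            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                traverse(s, tensor.op, scheduled_ops)
    scheduled_ops.append(OP)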

Thanks, it works well now; the problem is solved. Thanks again.

Hi @masahi

I am facing a similar error while compiling an ONNX model.
Note: the error does not come up when compiling the model with NNVM at opt_level=0,
but only at opt_level=1, opt_level=2, or opt_level=3.
Code:
with nnvm.compiler.build_config(opt_level=1):
    graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype_dict, params=params)
Error:
Traceback (most recent call last):
File "sface_nchw_trail1.py", line 64, in
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype_dict, params=params)
File "/home/ubuntu/tvm_opencl/tvm/nnvm/python/nnvm/compiler/build_module.py", line 306, in build
graph = graph.apply("GraphCompile")
File "/home/ubuntu/tvm_opencl/tvm/nnvm/python/nnvm/graph.py", line 234, in apply
check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
File "/home/ubuntu/tvm_opencl/tvm/nnvm/python/nnvm/_base.py", line 75, in check_call
raise NNVMError(py_str(_LIB.NNGetLastError()))
nnvm._base.NNVMError: TVMCall CFunc Error:
Traceback (most recent call last):
File "/home/ubuntu/tvm_opencl/tvm/python/tvm/_ffi/_ctypes/function.py", line 55, in cfun
rv = local_pyfunc(*pyargs)
File "/home/ubuntu/tvm_opencl/tvm/nnvm/python/nnvm/compiler/build_module.py", line 124, in _build
return tvm.build(funcs, target=target, target_host=target_host)
File "/home/ubuntu/tvm_opencl/tvm/python/tvm/build_module.py", line 586, in build
fhost, mdev = _build_for_device(flist, tar, target_host)
File "/home/ubuntu/tvm_opencl/tvm/python/tvm/build_module.py", line 415, in _build_for_device
"Did you forget to bind?" % func.name)
ValueError: Direct host side access to device memory is detected in fuse_matmul_relu. Did you forget to bind?

I don't think we have a GPU schedule for matmul, so it falls back to the CPU matmul schedule, which gives the error you are seeing.
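For reference, writing a GPU schedule essentially means binding the compute loops to GPU blocks and threads. Below is a minimal sketch of a thread-bound matmul using the old tvm.* tensor-expression API seen elsewhere in this thread; the shapes, the tile factor of 16, and all names are illustrative, and this is not the TOPI/NNVM-registered schedule the error is asking for:

import tvm

n, m, k = 1024, 1024, 1024
A = tvm.placeholder((n, k), name="A")
B = tvm.placeholder((k, m), name="B")
rk = tvm.reduce_axis((0, k), name="rk")
C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, rk] * B[rk, j], axis=rk), name="C")

s = tvm.create_schedule(C.op)
# Tile the output and bind the loops to GPU blocks/threads; without these
# bind() calls, building for "cuda" raises the same
# "Direct host side access to device memory" ValueError.
yo, yi = s[C].split(C.op.axis[0], factor=16)
xo, xi = s[C].split(C.op.axis[1], factor=16)
s[C].bind(yo, tvm.thread_axis("blockIdx.y"))
s[C].bind(xo, tvm.thread_axis("blockIdx.x"))
s[C].bind(yi, tvm.thread_axis("threadIdx.y"))
s[C].bind(xi, tvm.thread_axis("threadIdx.x"))

fmatmul = tvm.build(s, [A, B, C], "cuda", name="matmul")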

Thanks @masahi
I tried to run it on OpenCL as well.

The error does not come up when compiling the model with NNVM at opt_level=0,
but only at opt_level=1, opt_level=2, or opt_level=3.

Hi @masahi

Is there any plan to write a GPU schedule for matmul?
Is there any way I can work on that? If so, could you please guide me?

I don’t think there is a plan to add a GPU schedule for matmul. But there is the batched_matmul schedule for GPU. You can use that from our onnx frontend.

Thanks @masahi. Could you guide me towards writing a GPU schedule for matmul, since one already exists for the CPU?

Hi @masahi @srkreddy1238 @FrozenGene @tqchen
I am facing the same error mentioned above.
Which line did you change so that it worked out for

fuse_reshape_broadcast_mul_conv2d_broadcast_mul_broadcast_add_elemwise_add. Did you forget to bind?

I’m getting the below error:

nnvm._base.NNVMError: ValueError: Direct host side access to device memory is detected in fuse_matmul_relu. Did you forget to bind?

I did not understand how to use batched_matmul for this.
Any help here?

Hi @masahi @Hkathuria @yqwang, I am facing the same issue for the Metal GPU.

Error: ValueError: Direct host side access to device memory is detected in addone. Did you forget to bind?

The Python script is:

import tvm
import os

def prepare_test_libs(base_path):
    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
    s = tvm.create_schedule(B.op)

    # Compile library as dynamic library
    fadd_dylib = tvm.build(s, [A, B], "metal", name="addone")
    dylib_path = os.path.join(base_path, "test_addone_dll.dylib")
    fadd_dylib.export_library(dylib_path)

if __name__ == "__main__":
    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    prepare_test_libs(os.path.join(curr_path, "./lib"))

I am building it for a Metal device.

You need to add a proper schedule, such as binding the axes to GPU threads.

Hi @vinx13, can you please suggest a sample for this?

See https://docs.tvm.ai/tutorials/tensor_expr_get_started.html#schedule-the-computation
Something like:
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
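Applied to the addone script above, that advice would look roughly like this; the split factor of 64 and the variable names are only an illustration of the pattern from that tutorial, and it needs a Metal-enabled TVM build to actually run:

import tvm
import os

def prepare_test_libs(base_path):
    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
    s = tvm.create_schedule(B.op)

    # Bind the single loop over B to GPU blocks/threads so no direct
    # host-side access to device memory is generated.
    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

    fadd_dylib = tvm.build(s, [A, B], "metal", name="addone")
    fadd_dylib.export_library(os.path.join(base_path, "test_addone_dll.dylib"))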