"Did you forget to bind? error" when building IRModule (topi depthwise conv2d op)

Hi, I am unable to build a depthwise convolution that uses topi.nn.depthwise_conv2d_nhwc with Ansor.

I attached the complete script and error below.

My script to build a depthwise convolution for the OpenCL target:

import os

import numpy as np
import tvm
from tvm import te, auto_scheduler, topi, relay, rpc, autotvm
from tvm.topi.testing import conv2d_nhwc_python

from tvm.topi.testing import depthwise_conv2d_python_nhwc

from tvm.contrib import utils, ndk
from tvm.contrib import graph_executor
from tvm.relay.op.contrib import clml

from tvm import relax

target = tvm.target.Target("opencl", host="llvm -mtriple=aarch64-linux-android")

dtype = "float16"

# input nhwc
N = 1
H = 112
W = 112
C = 32

# kernel RSC
R = 3
S = 3

STRIDE_H = 1
STRIDE_W = 1

PAD_H = 0
PAD_W = 0

strides = (STRIDE_H, STRIDE_W)
padding = (PAD_H, PAD_W)

# output npqk
P = (H + 2 * PAD_H - R) // STRIDE_H + 1
Q = (W + 2 * PAD_W - S) // STRIDE_W + 1

@auto_scheduler.register_workload
def depthwise_conv2d_func(N, H, W, C, R, S, stride=(STRIDE_H,STRIDE_W), padding=(PAD_H,PAD_W)):
    data = te.placeholder((N, H, W, C), name="data", dtype="float16")
    kernel = te.placeholder((R, S, C, 1), name="kernel", dtype="float16")
    dw_conv = topi.nn.depthwise_conv2d_nhwc(data, kernel, stride, padding, dilation=1, kernel_layout='HWOI', out_dtype="float16")
    return [data, kernel, dw_conv]

data, kernel, dw_conv = depthwise_conv2d_func(N, H, W, C, R, S, stride=(STRIDE_H,STRIDE_W), padding=(PAD_H,PAD_W))

sched = te.create_schedule(dw_conv.op)

mod = tvm.lower(sched, [data, kernel, dw_conv], simple_mode=True)
print(mod)

rt_mod = tvm.build(mod, target=target)
print(rt_mod.imported_modules[0].get_source())  # the device (OpenCL) module holds the kernel source

The error:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(data: T.Buffer((1, 112, 112, 32), "float16"), kernel: T.Buffer((3, 3, 32, 1), "float16"), DepthwiseConv2d: T.Buffer((1, 110, 110, 32), "float16")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        PaddedInput = T.allocate([401408], "float16", "global")
        PaddedInput_1 = T.Buffer((401408,), "float16", data=PaddedInput)
        for i1, i2, i3 in T.grid(112, 112, 32):
            cse_var_1: T.int32 = i1 * 3584 + i2 * 32 + i3
            data_1 = T.Buffer((401408,), "float16", data=data.data)
            PaddedInput_1[cse_var_1] = data_1[cse_var_1]
        for i, j, c in T.grid(110, 110, 32):
            DepthwiseConv2d_1 = T.Buffer((387200,), "float16", data=DepthwiseConv2d.data)
            DepthwiseConv2d_1[i * 3520 + j * 32 + c] = T.float16(0.0)
            for di, dj in T.grid(3, 3):
                cse_var_4: T.int32 = j * 32
                cse_var_3: T.int32 = dj * 32
                cse_var_2: T.int32 = i * 3520 + cse_var_4 + c
                kernel_1 = T.Buffer((288,), "float16", data=kernel.data)
                DepthwiseConv2d_1[cse_var_2] = DepthwiseConv2d_1[cse_var_2] + PaddedInput_1[i * 3584 + di * 3584 + cse_var_4 + cse_var_3 + c] * kernel_1[di * 96 + cse_var_3 + c]
---------------------------------------------------------------------------
TVMError                                  Traceback (most recent call last)
Cell In[1], line 58
     55 mod = tvm.lower(sched, [data, kernel, dw_conv], simple_mode=True)
     56 print(mod)
---> 58 rt_mod = tvm.build(mod, target=target)
     59 print(mod.get_source())

File ~/FastConvolution/src/FastConvolution/tvm2/python/tvm/driver/build_module.py:297, in build(inputs, args, target, target_host, runtime, name, binds)
    293     target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
    295 annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)
--> 297 rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
    299 annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)
    301 if not isinstance(target_host, Target):

File ~/FastConvolution/src/FastConvolution/tvm2/python/tvm/_ffi/_cython/packed_func.pxi:339, in tvm._ffi._cy3.core.PackedFuncBase.__call__()

File ~/FastConvolution/src/FastConvolution/tvm2/python/tvm/_ffi/_cython/packed_func.pxi:270, in tvm._ffi._cy3.core.FuncCall()

File ~/FastConvolution/src/FastConvolution/tvm2/python/tvm/_ffi/_cython/packed_func.pxi:259, in tvm._ffi._cy3.core.FuncCall3()

File ~/FastConvolution/src/FastConvolution/tvm2/python/tvm/_ffi/_cython/base.pxi:185, in tvm._ffi._cy3.core.CHECK_CALL()

File ~/FastConvolution/src/FastConvolution/tvm2/python/tvm/_ffi/base.py:481, in raise_last_ffi_error()
    475 # The exception PyObject may contain a large amount of state,
    476 # including all stack frames that may be inspected in a later
    477 # PDB post-mortem.  Therefore, we must make sure to remove the
    478 # underlying PyObject* from the C++ side after we retrieve it.
    479 _LIB.TVMDropLastPythonError()
--> 481 raise py_err

TVMError: Traceback (most recent call last):
  Did you forget to bind?
    Variable `kernel` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `DepthwiseConv2d` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `DepthwiseConv2d` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `DepthwiseConv2d` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `data` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "/Users/varunnaw/FastConvolution/src/FastConvolution/tvm2/src/tir/analysis/verify_memory.cc", line 205
RuntimeError: Memory verification failed with the following errors:
# from tvm.script import tir as T

@T.prim_func
def main(data: T.Buffer((1, 112, 112, 32), "float16"), kernel: T.Buffer((3, 3, 32, 1), "float16"), DepthwiseConv2d: T.Buffer((1, 110, 110, 32), "float16")):
    T.func_attr({"from_legacy_te_schedule": T.bool(True), "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
    PaddedInput = T.allocate([401408], "float16", "global")
    PaddedInput_1 = T.Buffer((401408,), "float16", data=PaddedInput)
    for i1, i2, i3 in T.grid(112, 112, 32):
        cse_var_1: T.int32 = i1 * 3584 + i2 * 32 + i3
        data_1 = T.Buffer((401408,), "float16", data=data.data)
        PaddedInput_1[cse_var_1] = data_1[cse_var_1]
    for i, j, c in T.grid(110, 110, 32):
        DepthwiseConv2d_1 = T.Buffer((387200,), "float16", data=DepthwiseConv2d.data)
        DepthwiseConv2d_1[i * 3520 + j * 32 + c] = T.float16(0.0)
        for di, dj in T.grid(3, 3):
            cse_var_4: T.int32 = j * 32
            cse_var_3: T.int32 = dj * 32
            cse_var_2: T.int32 = i * 3520 + cse_var_4 + c
            kernel_1 = T.Buffer((288,), "float16", data=kernel.data)
            DepthwiseConv2d_1[cse_var_2] = DepthwiseConv2d_1[cse_var_2] + PaddedInput_1[i * 3584 + di * 3584 + cse_var_4 + cse_var_3 + c] * kernel_1[di * 96 + cse_var_3 + c]

I don’t know what is meant by "binding" here.

I tried this, but no luck:

# Create buffer declarations for device
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, name="data_buf")
kernel_buf = tvm.tir.decl_buffer(kernel.shape, kernel.dtype, name="kernel_buf")
dw_conv_buf = tvm.tir.decl_buffer(dw_conv.shape, dw_conv.dtype, name="dw_conv_buf")
# Bind the buffers to the computation
binds = {data: data_buf, kernel: kernel_buf, dw_conv: dw_conv_buf}
rt_mod = tvm.build(mod, target=target, binds=binds)

Same error.
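
Reading the verify_memory message again, I suspect "bind" refers to binding the compute axes of the TE schedule to GPU block/thread indices, not to the binds= argument of tvm.build. Is that right? A minimal, untested sketch of what I imagine that looks like (the fuse/split structure and the factor of 64 are guesses on my part):

# Continuing from the script above (data, kernel, dw_conv, target already defined).
s = te.create_schedule(dw_conv.op)

# Inline the PaddedInput stage so it does not need thread bindings of its own.
padded = dw_conv.op.input_tensors[0]
s[padded].compute_inline()

# Fuse the output axes and split the result into block / thread indices.
n, h, w, c = s[dw_conv].op.axis
fused = s[dw_conv].fuse(n, h, w, c)
bx, tx = s[dw_conv].split(fused, factor=64)  # factor chosen arbitrarily
s[dw_conv].bind(bx, te.thread_axis("blockIdx.x"))
s[dw_conv].bind(tx, te.thread_axis("threadIdx.x"))

rt_mod = tvm.build(s, [data, kernel, dw_conv], target=target)
print(rt_mod.imported_modules[0].get_source())  # OpenCL kernel source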

We are moving towards MetaSchedule autotuning; there are some examples here: End-to-End Optimize Model — tvm 0.18.dev0 documentation
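
Roughly, tuning your TE workload with MetaSchedule would look something like the sketch below. It is untested; I am assuming te.create_prim_func, ms.tune_tir and ms.tir_integration.compile_tir from a recent TVM build, the trial count is arbitrary, and for an Android device you would also need to plug in an RPC runner (omitted here):

import tvm
from tvm import te, topi, meta_schedule as ms

target = tvm.target.Target("opencl", host="llvm -mtriple=aarch64-linux-android")

data = te.placeholder((1, 112, 112, 32), name="data", dtype="float16")
kernel = te.placeholder((3, 3, 32, 1), name="kernel", dtype="float16")
dw_conv = topi.nn.depthwise_conv2d_nhwc(data, kernel, (1, 1), (0, 0), dilation=1,
                                        kernel_layout="HWOI", out_dtype="float16")

# Wrap the TE compute into a TIR PrimFunc and tune it with MetaSchedule.
mod = tvm.IRModule({"main": te.create_prim_func([data, kernel, dw_conv])})
database = ms.tune_tir(mod=mod, target=target, work_dir="./ms_work_dir",
                       max_trials_global=64)

# Pick the best schedule found during tuning and build it.
sch = ms.tir_integration.compile_tir(database, mod, target)
rt_mod = tvm.build(sch.mod, target=target)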

Thank you for your response.

How do I get the target source code for this IRModule after the MetaSchedule pass? Also, if the target is OpenCL or CUDA, how do I get the global work dimensions and work-group dimensions (or grid and block dimensions for CUDA)?
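
For reference, after a plain tvm.build for a GPU target I would normally pull the device source like this (my own assumption of the usual pattern); I am mainly asking whether there is an equivalent after MetaSchedule, and where the launch dimensions show up:

# The generated kernel lives in the imported device module;
# the host module only contains the launch glue.
dev_mod = rt_mod.imported_modules[0]   # OpenCL / CUDA device module
print(dev_mod.get_source())            # kernel source code

# The work-group / global-work (block / grid) extents appear as the extents of the
# blockIdx.* / threadIdx.* bindings when the lowered module is printed.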

Other questions: is MetaSchedule designed to find a schedule for an arbitrary tensor expression, just as Ansor would?

Also, I was having trouble finding even one valid schedule with Ansor for a 2D convolution operator I specified in TE (I used topi as well): https://gist.github.com/poltomo/d3dba060e375971f612cee1153c865d7 Is this an Ansor problem?