Quantization and 3D convolution

Hi all,

I would like to contribute to this project by implementing 8-bit quantization for 3D convolution. Currently, my implementation works fine without auto-tuning. It is quite similar to what is done in 2D:

  1. Reshape the input data and the kernel so that the convolution computation can be vectorized

  2. Perform the convolution computation in a vectorized fashion via dp4a.

  3. Reshape the output

The 8-bit convolution outputs are relatively close to the standard convolution ones.
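To make the packing step concrete, here is a rough NumPy sketch (toy shapes and a 1x1x1 kernel, not the actual TOPI code) of what steps 1 and 2 amount to: the channel axis is split into blocks of 4 so that each dp4a call reduces 4 int8 pairs into an int32 accumulator.

import numpy as np

# Hypothetical toy shapes: N=1, C=8, D=H=W=4, and a 1x1x1 kernel with 8 output channels.
data = np.random.randint(-128, 127, size=(1, 8, 4, 4, 4), dtype=np.int8)
kernel = np.random.randint(-128, 127, size=(8, 8, 1, 1, 1), dtype=np.int8)

ic_block = 4
n, c, d, h, w = data.shape

# Step 1: pack NCDHW -> NCDHW4c so the innermost axis holds 4 input channels.
packed = data.reshape(n, c // ic_block, ic_block, d, h, w).transpose(0, 1, 3, 4, 5, 2)

# Step 2: each dp4a accumulates a dot product of 4 int8 pairs into an int32.
def dp4a_ref(a4, b4, acc):
    return acc + int(np.dot(a4.astype(np.int32), b4.astype(np.int32)))

# Output element (n=0, oc=0, d=0, h=0, w=0): loop over channel chunks, one dp4a per chunk.
acc = 0
for icc in range(c // ic_block):
    acc = dp4a_ref(packed[0, icc, 0, 0, 0],
                   kernel[0, icc * ic_block:(icc + 1) * ic_block, 0, 0, 0],
                   acc)

# Step 3 is the reverse reshape of the NCDHWc output back to NCDHW.
print(acc)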

The auto-tuning step runs smoothly (although it takes quite some time) and outputs a log file with the optimal configuration for the 3D convolution (conv3d_ncdhw_int8).

However, during the compilation phase, I sometimes encounter the following error:

[12:11:45] /usr/tvm/src/tir/transforms/loop_partition.cc:548: Cannot prove: ((((((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2)) - 1) - (29 - (blockIdx.z*4))) + 1) >= 0), when generating the post doubt loop
Traceback (most recent call last):
  File "tune_relay_cuda_int8.py", line 508, in <module>
    tune_and_evaluate(tuning_option)
  File "tune_relay_cuda_int8.py", line 409, in tune_and_evaluate
    graph, lib, params = relay.build_module.build(mod, target=target, params=params)
  File "/usr/tvm/python/tvm/relay/build_module.py", line 260, in build
    graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
  File "/usr/tvm/python/tvm/relay/build_module.py", line 127, in build
    self._build(mod, target, target_host)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 322, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /usr/tvm/build/libtvm.so(tvm::build(tvm::Map<tvm::runtime::String, tvm::IRModule, void, void> const&, tvm::Target const&)+0x83c) [0x7fd6f772267c]
  [bt] (7) /usr/tvm/build/libtvm.so(tvm::build(tvm::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)+0x2c7) [0x7fd6f7721397]
  [bt] (6) /usr/tvm/build/libtvm.so(tvm::SplitDevHostFuncs(tvm::IRModule, tvm::Target const&, tvm::Target const&, tvm::transform::PassContext const&)+0x488) [0x7fd6f771fff8]
  [bt] (5) /usr/tvm/build/libtvm.so(tvm::transform::Pass::operator()(tvm::IRModule) const+0x6a) [0x7fd6f71d8e7a]
  [bt] (4) /usr/tvm/build/libtvm.so(tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x40e) [0x7fd6f7241d1e]
  [bt] (3) /usr/tvm/build/libtvm.so(tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x1e2) [0x7fd6f723fe52]
  [bt] (2) /usr/tvm/build/libtvm.so(+0x8d347c) [0x7fd6f74d147c]
  [bt] (1) /usr/tvm/build/libtvm.so(tvm::tir::MakePackedAPI(tvm::tir::PrimFunc&&, int)+0x2d19) [0x7fd6f74ce7a9]
  [bt] (0) /usr/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x61) [0x7fd6f7138f91]
  File "/usr/tvm/src/tir/transforms/make_packed_api.cc", line 210
TVMError: Not all Vars are passed in api_args:  'threadIdx.z'  is not bound to any variables

Depending on the optimization found by the auto-tuner, this error may or may not occur. For instance, by modifying the log produced during auto-tuning, I am able to make an invalid configuration actually work (the only difference between the two entries below is the last factor of tile_f).

Invalid configuration

{"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -model=unknown -thread_warp_size=32", "conv3d_NCDHWc_int8.cuda", [["TENSOR", [1, 128, 18, 56, 56], "int8"], ["TENSOR", [128, 128, 3, 3, 3], "int8"], [1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1], "NCDHW", "int32"], {}], "config": {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 2]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]], ["tile_x", "sp", [-1, 2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1, 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp", [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot", 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908], "version": 0.2, "tvm_version": "0.8.dev0"}

Valid configuration

{"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -model=unknown -thread_warp_size=32", "conv3d_NCDHWc_int8.cuda", [["TENSOR", [1, 128, 18, 56, 56], "int8"], ["TENSOR", [128, 128, 3, 3, 3], "int8"], [1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1], "NCDHW", "int32"], {}], "config": {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 1]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]], ["tile_x", "sp", [-1, 2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1, 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp", [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot", 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908], "version": 0.2, "tvm_version": "0.8.dev0"}

I am not sure how to solve this problem. What would you advise?
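For reference, this is roughly how I apply the (possibly hand-edited) log during compilation; the Relay module below is only a stand-in for the real network, using the same shapes as the log entries above:

import tvm
from tvm import autotvm, relay

# Stand-in workload: a single int8 conv3d with the shapes from the log entries.
data = relay.var("data", shape=(1, 128, 18, 56, 56), dtype="int8")
weight = relay.var("weight", shape=(128, 128, 3, 3, 3), dtype="int8")
out = relay.nn.conv3d(data, weight, padding=(1, 1, 1), out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

target = "cuda"
log_file = "conv3d_int8.log"  # hypothetical path to the (edited) tuning log

# Compile while replaying the log; the entry matching the workload is picked up here,
# and this is where the compilation error above shows up with the invalid entry.
with autotvm.apply_history_best(log_file):
    with tvm.transform.PassContext(opt_level=3):
        graph, lib, params = relay.build_module.build(mod, target=target)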


I implemented the int8 conv3d as follows:

I created the file python/tvm/topi/cuda/conv3d_int8.py, which implements the operation itself:

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=no-value-for-parameter
"""Int8 conv3d in NCDHWc layout"""
import tvm
from tvm import te
from tvm import autotvm

from .injective import schedule_injective_from_existing
from .tensor_intrin import dp4a
from ..nn.pad import pad
from ..nn.conv3d import unpack_NCDHWc_to_ncdhw
from ..nn.util import get_pad_tuple3d
from ..util import simplify, get_const_tuple, traverse_inline


def conv3d_ncdhw_int8(data, kernel, strides, padding, dilation, out_dtype="int32"):
    """Compute conv3d internally using conv3d_ncdhwc layout for int8 dtype"""
    assert data.dtype in ("int8", "uint8")
    assert kernel.dtype in ("int8", "uint8")
    assert data.dtype == kernel.dtype
    packed_out = conv3d_NCDHWc_int8(data, kernel, strides, padding, dilation, "NCDHW", out_dtype)
    return unpack_NCDHWc_to_ncdhw(packed_out, out_dtype)


def schedule_conv3d_ncdhw_int8(outs):
    """Create schedule for tensors"""
    return schedule_conv3d_NCDHWc_int8(outs)


@autotvm.register_topi_compute("conv3d_NCDHWc_int8.cuda")
def conv3d_NCDHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_dtype):
    """Convolution operator in NCDHW[x]c layout for int8."""

    # print("conv3d_NCDHWc_int8")

    assert layout in ["NCDHW", "NCDHW4c"]

    ic_block_factor = 4
    oc_block_factor = 4

    pre_computed = len(kernel.shape) == 7
    if not pre_computed:
        batch, channels, depth, height, width = get_const_tuple(data.shape)
        assert (
            channels % ic_block_factor == 0
        ), "Number of input channels should be multiple of {}".format(ic_block_factor)
        packed_data = te.compute(
            (batch, channels // ic_block_factor, depth, height, width, ic_block_factor),
            lambda n, c, d, h, w, vc: data[n, c * ic_block_factor + vc, d, h, w],
            name="packed_data",
        )

        out_channels, in_channels, kernel_d, kernel_h, kernel_w = get_const_tuple(kernel.shape)
        assert (
            out_channels % oc_block_factor == 0
        ), "Number of output channels should be multiple of {}".format(oc_block_factor)
        packed_kernel = te.compute(
            (
                out_channels // oc_block_factor,
                in_channels // ic_block_factor,
                kernel_d,
                kernel_h,
                kernel_w,
                oc_block_factor,
                ic_block_factor,
            ),
            lambda oc_chunk, ic_chunk, kd, kh, kw, oc_block, ic_block: kernel[
                oc_chunk * oc_block_factor + oc_block,
                ic_chunk * ic_block_factor + ic_block,
                kd,
                kh,
                kw,
            ],
            name="packed_kernel",
        )

    else:
        packed_data = data
        packed_kernel = kernel

    batch, ic_chunk, in_depth, in_height, in_width, ic_block = get_const_tuple(packed_data.shape)
    oc_chunk, ic_chunk, kernel_d, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
        packed_kernel.shape
    )
    assert isinstance(stride, int) or len(stride) == 3
    assert isinstance(dilation, int) or len(dilation) == 3

    if isinstance(stride, int):
        stride_d = stride_h = stride_w = stride
    else:
        stride_d, stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_d = dilation_h = dilation_w = dilation
    else:
        dilation_d, dilation_h, dilation_w = dilation

    # compute the output shape

    pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
        padding, (kernel_d, kernel_h, kernel_w)
    )
    # out_channel = num_filter
    out_depth = (in_depth - kernel_d + pad_front + pad_back) // stride_d + 1
    out_height = (in_height - kernel_h + pad_top + pad_down) // stride_h + 1
    out_width = (in_width - kernel_w + pad_left + pad_right) // stride_w + 1

    oshape = (batch, oc_chunk, out_depth, out_height, out_width, oc_block)
    # compute graph
    pad_before = [0, 0, pad_front, pad_top, pad_left, 0]
    pad_after = [0, 0, pad_back, pad_down, pad_right, 0]
    pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")

    icc = te.reduce_axis((0, ic_chunk), name="ic_chunk")
    icb = te.reduce_axis((0, ic_block), name="ic_block")
    rz = te.reduce_axis((0, kernel_d), name="rz")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    conv = te.compute(
        oshape,
        lambda nn, oc_chunk, zz, yy, xx, oc_block: te.sum(
            pad_data[
                nn,
                icc,
                zz * stride_d + rz * dilation_d,
                yy * stride_h + ry * dilation_h,
                xx * stride_w + rx * dilation_w,
                icb,
            ].astype("int32")
            * packed_kernel[oc_chunk, icc, rz, ry, rx, oc_block, icb].astype("int32"),
            axis=[icc, rz, ry, rx, icb],
        ),
    )


    output = te.compute(
        oshape,
        lambda nn, oc_chunk, zz, yy, xx, oc_block: conv[nn, oc_chunk, zz, yy, xx, oc_block].astype(
            out_dtype
        ),
        tag="conv3d_NCDHWc_int8",
    )

    # num flop
    num_flop = (
        batch
        * oc_chunk
        * oc_block
        * out_depth
        * out_height
        * out_width
        * ic_chunk
        * ic_block
        * kernel_d
        * kernel_h
        * kernel_w
        * 2
    )
    cfg.add_flop(num_flop)

    return output


_dp4a = dp4a("shared", "shared", "local")


@autotvm.register_topi_schedule("conv3d_NCDHWc_int8.cuda")
def schedule_conv3d_NCDHWc_int8(cfg, outs):
    """Schedule conv3d int8 NCDHWc template"""
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])

    def _callback(op):
        if op.tag == "conv3d_NCDHWc_int8":
            _schedule_conv3d_NCDHWc_int8(cfg, s, op.output(0), "NCDHW", "conv3d_NCDHWc_int8.cuda")

    traverse_inline(s, outs[0].op, _callback)
    return s


def _schedule_conv3d_NCDHWc_int8(cfg, s, output, layout, workload_name):

    conv = output.op.input_tensors[0]
    packed_data, packed_kernel = conv.op.input_tensors

    if isinstance(packed_data.op, tvm.te.ComputeOp) and "pad" in packed_data.op.tag:
        pad_data = packed_data
        packed_data = pad_data.op.input_tensors[0]
    else:
        pad_data = packed_data

    if autotvm.GLOBAL_SCOPE.in_tuning:
        # skip this part during tuning to make records accurate
        # this part will be pre-computed during NNVM's pre-compute optimization pass
        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
        s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
    else:
        if isinstance(packed_kernel.op, tvm.te.ComputeOp) and packed_kernel.name == "packed_kernel":
            # data and kernel are not pre-computed, schedule layout transform here
            schedule_injective_from_existing(s, packed_data)
            schedule_injective_from_existing(s, packed_kernel)
    if pad_data != packed_data:
        s[pad_data].compute_inline()

    AA = s.cache_read(pad_data, "shared", [conv])
    WW = s.cache_read(packed_kernel, "shared", [conv])

    s[conv].set_scope("local")

    # handle bias
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    # tile and bind spatial axes
    if len(s[output].op.axis) == 6:
        n, f, d, y, x, c = s[output].op.axis
    else:
        # For task extraction of auto-tuning, the expected output is 5D. Since auto-tuning tasks
        # are created from scratch, the real auto-tuning will still happen on the 6D output.
        n, f, d, y, x = s[output].op.axis

    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
    cfg.define_split("tile_d", cfg.axis(d), num_outputs=4)
    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)

    kernel_scope, n = s[output].split(n, nparts=1)

    # bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    bd, vd, td, di = cfg["tile_d"].apply(s, output, d)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(bf, bd, by, bx, vf, vd, vy, vx, tf, td, ty, tx, fi, di, yi, xi)

    bf = s[output].fuse(n, bf)

    s[output].bind(bf, te.thread_axis("blockIdx.z"))
    s[output].bind(bd, te.thread_axis("blockIdx.y"))
    s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vd, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))

    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
    if cfg["fuse_yx"].val:
        s[output].bind(tf, te.thread_axis("threadIdx.z"))
        s[output].bind(td, te.thread_axis("threadIdx.y"))
        tyx = s[output].fuse(ty, tx)
        s[output].bind(tyx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tyx)

        # number of threads
        n_tz = cfg["tile_f"].size[2]
        n_ty = cfg["tile_d"].size[2]
        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
    else:
        s[output].bind(s[output].fuse(tf, td), te.thread_axis("threadIdx.z"))
        s[output].bind(ty, te.thread_axis("threadIdx.y"))
        s[output].bind(tx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tx)

        # number of threads
        n_tz = cfg["tile_d"].size[2] * cfg["tile_f"].size[2]
        n_ty = cfg["tile_y"].size[2]
        n_tx = cfg["tile_x"].size[2]

    # tile reduction axes
    n, f, d, y, x, c = s[conv].op.axis
    rc, rd, ry, rx, rc_block = s[conv].op.reduce_axis

    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
    cfg.define_split("tile_rd", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
    rco, rci = cfg["tile_rc"].apply(s, conv, rc)
    rdo, rdi = cfg["tile_rd"].apply(s, conv, rd)
    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
    s[conv].reorder(rco, rdo, ryo, rxo, rci, rdi, ryi, rxi, n, f, d, y, x, c, rc_block)

    cfg.define_reorder("reorder_inner", [rco, rdo, ryo, rxo], policy="all")
    cfg["reorder_inner"].apply(s, conv, [rco, rdo, ryo, rxo])
    cfg["reorder_inner"].apply(s, conv, [rci, rdi, ryi, rxi])

    _, rc_block = s[conv].split(rc_block, factor=4)
    s[conv].tensorize(rc_block, _dp4a)

    cache_loc = [rco, rdo, ryo, rxo][cfg["reorder_inner"].perm[-1]]
    s[AA].compute_at(s[conv], cache_loc)
    s[WW].compute_at(s[conv], cache_loc)

    # # cooperative fetching
    for load in [AA, WW]:

        c = s[load].op.axis[-1]
        c_outer, c = s[load].split(c, factor=4)
        s[load].vectorize(c)
        fused = s[load].op.axis[:-1] + [c_outer]
        fused = s[load].fuse(*fused)
        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        fused, tz = s[load].split(fused, factor=n_tz)
        s[load].bind(tz, te.thread_axis("threadIdx.z"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))
        s[load].bind(tx, te.thread_axis("threadIdx.x"))

    # unroll
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", False)

    return s

In the file python/tvm/relay/op/strategy/cuda.py, I linked the new implementation into conv3d_strategy_cuda as follows:


@conv3d_strategy.register(["cuda", "gpu"])
def conv3d_strategy_cuda(attrs, inputs, out_type, target):
    """conv3d cuda strategy"""
    strategy = _op.OpStrategy()
    data, kernel = inputs
    layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout
    _, stride_h, stride_w = attrs.get_int_tuple("strides")
    _, dilation_h, dilation_w = attrs.get_int_tuple("dilation")
    assert layout in ["NCDHW", "NDHWC"], "Not support this layout {} yet".format(layout)
    if layout == "NCDHW":

        if attrs.groups == 1:
            assert kernel_layout == "OIDHW"
            if data.dtype in ("int8", "uint8") and kernel.dtype in ("int8", "uint8"):
                assert data.dtype == kernel.dtype

                strategy.add_implementation(
                    wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_int8),
                    wrap_topi_schedule(topi.cuda.schedule_conv3d_NCDHWc_int8),
                    name="conv3d_ncdhw_int8.cuda",
                )
            else:
                strategy.add_implementation(
                    wrap_compute_conv3d(topi.cuda.conv3d_ncdhw),
                    wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
                    name="conv3d_ncdhw.cuda",
                    plevel=10,
                )
            _, _, _, kh, kw = get_const_tuple(kernel.shape)
            if (
                2 < kh < 8
                and 2 < kw < 8
                and kh == kw
                and stride_h == 1
                and stride_w == 1
                and dilation_h == 1
                and dilation_w == 1
            ):
                strategy.add_implementation(
                    wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd),
                    wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw_winograd),
                    name="conv3d_ncdhw_winograd.cuda",
                    plevel=5,
                )
       

    else:  # layout == "NDHWC":
        strategy.add_implementation(
            wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
            wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
            name="conv3d_ndhwc.cuda",
            plevel=10,
        )
        N, _, _, _, _ = get_const_tuple(data.shape)
        _, _, _, CI, CO = get_const_tuple(kernel.shape)
        if target.kind.name == "cuda":
            if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
                if (
                    (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0)
                    or (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0)
                    or (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
                ):
                    strategy.add_implementation(
                        wrap_compute_conv3d(topi.cuda.conv3d_ndhwc_tensorcore),
                        wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc_tensorcore),
                        name="conv3d_ndhwc_tensorcore.cuda",
                        plevel=20,
                    )

    if target.kind.name == "cuda" and "cudnn" in target.libs:
        strategy.add_implementation(
            wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True),
            wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn),
            name="conv3d_cudnn.cuda",
            plevel=25,
        )
    return strategy

In the file python/tvm/relay/quantize/_annotate.py, I defined new annotation functions such as:

@register_annotate_function("nn.contrib_conv3d_NCDHWc")
def conv3d_ncdhwc_rewrite(ref_call, new_args, ctx):
    warnings.warn(
        "NCDHWc layout Conv3D detected, please use a lower "
        "optimization level before applying the quantization "
        "pass as quantization will have no effect here..."
    )


@register_annotate_function("nn.conv3d")
def conv3d_rewrite(ref_call, new_args, ctx):
    """Rewrite function for conv2d. Lhs of conv will be quantized to
    input field, and rhs of conv will be quantized to weight field.
    Output would be in activation field"""

    if quantize_context().check_to_skip(ref_call):

        return None

    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
    rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

    if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:

        lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)

    assert rhs_kind is None
    rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)

    expr = _forward_op(ref_call, [lhs_expr, rhs_expr])

    return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

I also registered a new partition function in python/tvm/relay/quantize/_partition.py:

@register_partition_function("nn.conv3d")
def conv3d_partition_function(ref_call, new_args, ctx):
    """Rewrite function for conv3d for partition"""
    data_cond, data = partition_expr_check(new_args[0])
    kernel_cond, kernel = partition_expr_check(new_args[1])

    assert not kernel_cond
    if data_cond:
        data = new_args[0].realize()
    ret = _forward_op(ref_call, [data, kernel])
    return QPartitionExpr(ret)

I also implemented Conv3dRealize:

Expr Conv3dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);
  if (!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()) {
    return Expr(nullptr);
  }
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  CHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  CHECK(rhs);

  Expr ldata = lhs->data;
  if (lhs->dtype != cfg->dtype_input) {
    ldata = Cast(ldata, cfg->dtype_input);
  }
  Expr rdata = Cast(rhs->data, cfg->dtype_weight);

  const auto ref_attrs = ref_call->attrs.as<Conv3DAttrs>();
  auto attrs = make_object<Conv3DAttrs>();
  *attrs = *ref_attrs;
  DataType out_dtype = cfg->dtype_activation;
  attrs->out_dtype = out_dtype;

  Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
  Expr dom_scale = FoldConstantOpt(mul);
  return QRealizeIntExpr(ret, dom_scale, out_dtype);
}

RELAY_REGISTER_OP("nn.conv3d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv3dRealize);

As mentioned previously, the int8-based 3D convolution alone gives the right result and can be optimized by TVM's auto-tuning module. However, during the compilation phase I often encounter the error mentioned above. I figured out that, depending on the optimization found by the auto-tuner, the error may or may not occur. I don't know how to solve this issue.


Hello @OValery16, I believe the issue you are encountering is that you are calling te.thread_axis("threadIdx.z") multiple times. Instead, can you try creating the thread axis once with thread_z = te.thread_axis("threadIdx.z") and then using it like so: s[output].bind(s[output].fuse(tf, td), thread_z)? I think you'll also have to do this for threadIdx.x and threadIdx.y.
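Here is a tiny self-contained example of the pattern I mean (a toy elementwise op, not your conv3d schedule): the thread axis is created once and the same IterVar is reused for both the output stage and the shared-memory cache read.

import tvm
from tvm import te

A = te.placeholder((1024,), name="A")
B = te.compute((1024,), lambda i: A[i] * 2.0, name="B")

s = te.create_schedule(B.op)
AA = s.cache_read(A, "shared", [B])

bx, tx = s[B].split(B.op.axis[0], factor=64)
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis("threadIdx.x")  # created once ...
s[B].bind(bx, block_x)
s[B].bind(tx, thread_x)                   # ... bound to the output stage

s[AA].compute_at(s[B], bx)
_, tx_aa = s[AA].split(s[AA].op.axis[0], factor=64)
s[AA].bind(tx_aa, thread_x)               # ... and reused for the cooperative fetch

print(tvm.lower(s, [A, B], simple_mode=True))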

Hi @tkonolige,

Thanks a lot for your help.

Unfortunately, your fix didn't solve the problem.

I am a bit confused because my implementation is very close to the one for conv2d_NCHWc_int8:

def _schedule_conv2d_NCHWc_int8(cfg, s, output):
    conv = output.op.input_tensors[0]
    packed_data, packed_kernel = conv.op.input_tensors

    if isinstance(packed_data.op, tvm.te.ComputeOp) and "pad" in packed_data.op.tag:
        pad_data = packed_data
        packed_data = pad_data.op.input_tensors[0]
    else:
        pad_data = packed_data

    if autotvm.GLOBAL_SCOPE.in_tuning:
        # skip this part during tuning to make recrods accurate
        # this part will be pre-computed during NNVM's pre-compute optimization pass
        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
        s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
    else:
        if isinstance(packed_kernel.op, tvm.te.ComputeOp) and packed_kernel.name == "packed_kernel":
            # data and kernel are not pre-computed, schedule layout transform here
            schedule_injective_from_existing(s, packed_data)
            schedule_injective_from_existing(s, packed_kernel)

    if pad_data != packed_data:
        s[pad_data].compute_inline()

    # create cache stage
    AA = s.cache_read(pad_data, "shared", [conv])
    WW = s.cache_read(packed_kernel, "shared", [conv])

    s[conv].set_scope("local")

    # handle bias
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    # tile and bind spatial axes
    if len(s[output].op.axis) == 5:
        n, f, y, x, c = s[output].op.axis
    else:
        # For task extraction of auto-tuning, the expected output is 4D.  Since auto-tuning tasks
        # are created from scratch, therefore the real auto-tuning will still happen on 5D output.
        n, f, y, x = s[output].op.axis

    cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)

    # this is the scope to attach global config inside this kernel
    kernel_scope, n = s[output].split(n, nparts=1)

    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
    s[output].bind(bn, te.thread_axis("blockIdx.z"))
    s[output].bind(bf, te.thread_axis("blockIdx.y"))
    s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
    s[output].bind(vn, te.thread_axis("vthread"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))

    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
    if cfg["fuse_yx"].val:
        s[output].bind(tn, te.thread_axis("threadIdx.z"))
        s[output].bind(tf, te.thread_axis("threadIdx.y"))
        tyx = s[output].fuse(ty, tx)
        s[output].bind(tyx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tyx)

        # number of threads
        n_tz = cfg["tile_n"].size[2]
        n_ty = cfg["tile_f"].size[2]
        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
    else:
        s[output].bind(s[output].fuse(tn, tf), te.thread_axis("threadIdx.z"))
        s[output].bind(ty, te.thread_axis("threadIdx.y"))
        s[output].bind(tx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tx)

        # number of threads
        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
        n_ty = cfg["tile_y"].size[2]
        n_tx = cfg["tile_x"].size[2]

    # tile and bind reduction axes
    n, f, y, x, c = s[conv].op.axis

    rc, ry, rx, rc_block = s[conv].op.reduce_axis
    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
    rco, rci = cfg["tile_rc"].apply(s, conv, rc)
    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)

    s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x, c, rc_block)

    cfg.define_reorder("reorder_inner", [rco, ryo, rxo], policy="all")
    cfg["reorder_inner"].apply(s, conv, [rco, ryo, rxo])
    cfg["reorder_inner"].apply(s, conv, [rci, ryi, rxi])

    _, rc_block = s[conv].split(rc_block, factor=4)
    s[conv].tensorize(rc_block, _dp4a)

    cache_loc = [rco, ryo, rxo][cfg["reorder_inner"].perm[-1]]
    s[AA].compute_at(s[conv], cache_loc)
    s[WW].compute_at(s[conv], cache_loc)

    # cooperative fetching
    for load in [AA, WW]:
        print(s[load].op.axis)
        c = s[load].op.axis[-1]
        c_outer, c = s[load].split(c, factor=4)
        s[load].vectorize(c)
        fused = s[load].op.axis[:-1] + [c_outer]
        fused = s[load].fuse(*fused)

        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        fused, tz = s[load].split(fused, factor=n_tz)
        s[load].bind(tz, te.thread_axis("threadIdx.z"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))
        s[load].bind(tx, te.thread_axis("threadIdx.x"))

    # double buffer
    cfg.define_knob("AA_double_buffer", [0, 1])
    cfg.define_knob("WW_double_buffer", [0, 1])
    if cfg["AA_double_buffer"].val:
        s[AA].double_buffer()
    if cfg["WW_double_buffer"].val:
        s[WW].double_buffer()

    # unroll
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", False)

    return s

Could you print out the lowered code? You can use tvm.lower(s, args) where s is the schedule. Also, if you provide a minimal example to run, I can take a look at it.
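For example, something like this (assuming the placeholders and the unpacked output from your compute definition are called A, W and output_unpack):

# Print the lowered TIR for the scheduled conv3d before building it.
print(tvm.lower(s, [A, W, output_unpack], simple_mode=True))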

@tkonolige Thanks a lot for your help.

Regarding tvm.lower(s, args), you can find the generated code below.

Before tuning, I got:

#[version = "0.0.5"]
primfn(A_1: handle, W_1: handle, output_unpack_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {output_unpack: Buffer(output_unpack_2: Pointer(int32), int32, [1, 128, 18, 56, 56], []),
             W: Buffer(W_2: Pointer(int8), int8, [128, 128, 3, 3, 3], []),
             A: Buffer(A_2: Pointer(int8), int8, [1, 128, 18, 56, 56], [])}
  buffer_map = {A_1: A, W_1: W, output_unpack_1: output_unpack} {
  attr [packed_data: Pointer(int8)] "storage_scope" = "global";
  allocate(packed_data, int8, [7225344]);
  attr [packed_kernel: Pointer(int8)] "storage_scope" = "global";
  allocate(packed_kernel, int8, [442368]) {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 256;
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024;
    for (n.c.fused.d.fused.h.fused.w.fused.vc.fused.outer: int32, 0, 28) {
      if ((((n.c.fused.d.fused.h.fused.w.fused.vc.fused.outer*65536) + (blockIdx.x*256)) + floordiv(threadIdx.x, 4)) < 1806336) {
        if ((((n.c.fused.d.fused.h.fused.w.fused.vc.fused.outer*262144) + (blockIdx.x*1024)) + threadIdx.x) < 7225344) {
          packed_data[(((n.c.fused.d.fused.h.fused.w.fused.vc.fused.outer*262144) + (blockIdx.x*1024)) + threadIdx.x)] = (int8*)A_2[(((floordiv((((n.c.fused.d.fused.h.fused.w.fused.vc.fused.outer*65536) + (blockIdx.x*256)) + floordiv(threadIdx.x, 4)), 56448)*225792) + (floormod(threadIdx.x, 4)*56448)) + floormod((((n.c.fused.d.fused.h.fused.w.fused.vc.fused.outer*65536) + (blockIdx.x*256)) + floordiv(threadIdx.x, 4)), 56448))]
        }
      }
    }
    attr [IterVar(blockIdx.x_1: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 256;
    attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024;
    for (oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer: int32, 0, 2) {
      if ((((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*16384) + (blockIdx.x_1*64)) + floordiv(threadIdx.x_1, 16)) < 27648) {
        if ((((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*65536) + (blockIdx.x_1*256)) + floordiv(threadIdx.x_1, 4)) < 110592) {
          if ((((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*262144) + (blockIdx.x_1*1024)) + threadIdx.x_1) < 442368) {
            packed_kernel[(((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*262144) + (blockIdx.x_1*1024)) + threadIdx.x_1)] = (int8*)W_2[(((((floordiv((((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*16384) + (blockIdx.x_1*64)) + floordiv(threadIdx.x_1, 16)), 864)*13824) + (floordiv(floormod(threadIdx.x_1, 16), 4)*3456)) + (floordiv(floormod((((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*16384) + (blockIdx.x_1*64)) + floordiv(threadIdx.x_1, 16)), 864), 27)*108)) + (floormod(threadIdx.x_1, 4)*27)) + floormod((((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*16384) + (blockIdx.x_1*64)) + floordiv(threadIdx.x_1, 16)), 27))]
          }
        }
      }
    }
    attr [IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")] "thread_extent" = 128;
    attr [compute: Pointer(int32)] "storage_scope" = "local";
    allocate(compute, int32, [1]);
    attr [pad_data.shared: Pointer(int8)] "storage_scope" = "shared";
    allocate(pad_data.shared, int8x4, [1]);
    attr [packed_kernel.shared: Pointer(int8)] "storage_scope" = "shared";
    allocate(packed_kernel.shared, int8x4, [1]);
    attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 18;
    attr [IterVar(blockIdx.x_2: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 3136;
    attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
    attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1;
    attr [IterVar(threadIdx.x_2: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1 {
      compute[0] = 0
      for (ic_chunk.outer: int32, 0, 32) {
        for (rz.outer: int32, 0, 3) {
          for (ry.outer: int32, 0, 3) {
            for (rx.outer: int32, 0, 3) {
              attr [IterVar(threadIdx.z_1: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
              attr [IterVar(threadIdx.y_1: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1;
              attr [IterVar(threadIdx.x_3: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1;
              pad_data.shared[ramp(0, 1, 4)] = @tir.if_then_else(((((((1 <= (blockIdx.y + rz.outer)) && ((blockIdx.y + rz.outer) < 19)) && (1 <= (floordiv(blockIdx.x_2, 56) + ry.outer))) && ((floordiv(blockIdx.x_2, 56) + ry.outer) < 57)) && (1 <= (rx.outer + floormod(blockIdx.x_2, 56)))) && ((rx.outer + floormod(blockIdx.x_2, 56)) < 57)), (int8x4*)packed_data[ramp((((((((ic_chunk.outer*225792) + (blockIdx.y*12544)) + (rz.outer*12544)) + (ry.outer*224)) + (blockIdx.x_2*4)) + (rx.outer*4)) - 12772), 1, 4)], broadcast(0i8, 4), dtype=int8x4)
              attr [IterVar(threadIdx.z_2: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
              attr [IterVar(threadIdx.y_2: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1;
              attr [IterVar(threadIdx.x_4: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1;
              packed_kernel.shared[ramp(0, 1, 4)] = (int8x4*)packed_kernel[ramp(((((((floordiv(blockIdx.z, 4)*13824) + (ic_chunk.outer*432)) + (rz.outer*144)) + (ry.outer*48)) + (rx.outer*16)) + (floormod(blockIdx.z, 4)*4)), 1, 4)]
              compute[0] = @tir.call_pure_extern("__dp4a", (int8x4*)pad_data.shared[ramp(0, 1, 4)], (int8x4*)packed_kernel.shared[ramp(0, 1, 4)], (int32*)compute[0], dtype=int32)
            }
          }
        }
      }
      output_unpack_2[(((blockIdx.z*56448) + (blockIdx.y*3136)) + blockIdx.x_2)] = (int32*)compute[0]
    }
  }
}

#[metadata]
{
  "root": 1, 
  "nodes": [
    {
      "type_key": ""
    }, 
    {
      "type_key": "Map", 
      "keys": [
        "IntImm"
      ], 
      "data": [2]
    }, 
    {
      "type_key": "Array", 
      "data": [3]
    }, 
    {
      "type_key": "IntImm", 
      "attrs": {
        "dtype": "bool", 
        "value": "1"
      }
    }
  ], 
  "b64ndarrays": [], 
  "attrs": {"tvm_version": "0.8.dev0"}
}

After tuning, I got:

#[version = "0.0.5"]
primfn(A_1: handle, W_1: handle, output_unpack_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {output_unpack: Buffer(output_unpack_2: Pointer(int32), int32, [1, 128, 18, 56, 56], []),
             W: Buffer(W_2: Pointer(int8), int8, [128, 128, 3, 3, 3], []),
             A: Buffer(A_2: Pointer(int8), int8, [1, 128, 18, 56, 56], [])}
  buffer_map = {A_1: A, W_1: W, output_unpack_1: output_unpack} {
  attr [packed_data: Pointer(int8)] "storage_scope" = "global";
  allocate(packed_data, int8, [7225344]);
  attr [packed_kernel: Pointer(int8)] "storage_scope" = "global";
  allocate(packed_kernel, int8, [442368]) {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 256;
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024;
    for (n.c.fused.d.fused.h.fused.w.fused.vc.fused.outer: int32, 0, 28) {
      if ((((n.c.fused.d.fused.h.fused.w.fused.vc.fused.outer*65536) + (blockIdx.x*256)) + floordiv(threadIdx.x, 4)) < 1806336) {
        if ((((n.c.fused.d.fused.h.fused.w.fused.vc.fused.outer*262144) + (blockIdx.x*1024)) + threadIdx.x) < 7225344) {
          packed_data[(((n.c.fused.d.fused.h.fused.w.fused.vc.fused.outer*262144) + (blockIdx.x*1024)) + threadIdx.x)] = (int8*)A_2[(((floordiv((((n.c.fused.d.fused.h.fused.w.fused.vc.fused.outer*65536) + (blockIdx.x*256)) + floordiv(threadIdx.x, 4)), 56448)*225792) + (floormod(threadIdx.x, 4)*56448)) + floormod((((n.c.fused.d.fused.h.fused.w.fused.vc.fused.outer*65536) + (blockIdx.x*256)) + floordiv(threadIdx.x, 4)), 56448))]
        }
      }
    }
    attr [IterVar(blockIdx.x_1: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 256;
    attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024;
    for (oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer: int32, 0, 2) {
      if ((((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*16384) + (blockIdx.x_1*64)) + floordiv(threadIdx.x_1, 16)) < 27648) {
        if ((((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*65536) + (blockIdx.x_1*256)) + floordiv(threadIdx.x_1, 4)) < 110592) {
          if ((((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*262144) + (blockIdx.x_1*1024)) + threadIdx.x_1) < 442368) {
            packed_kernel[(((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*262144) + (blockIdx.x_1*1024)) + threadIdx.x_1)] = (int8*)W_2[(((((floordiv((((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*16384) + (blockIdx.x_1*64)) + floordiv(threadIdx.x_1, 16)), 864)*13824) + (floordiv(floormod(threadIdx.x_1, 16), 4)*3456)) + (floordiv(floormod((((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*16384) + (blockIdx.x_1*64)) + floordiv(threadIdx.x_1, 16)), 864), 27)*108)) + (floormod(threadIdx.x_1, 4)*27)) + floormod((((oc_chunk.ic_chunk.fused.kd.fused.kh.fused.kw.fused.oc_block.fused.ic_block.fused.outer*16384) + (blockIdx.x_1*64)) + floordiv(threadIdx.x_1, 16)), 27))]
          }
        }
      }
    }
    attr [IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")] "thread_extent" = 8;
    attr [compute: Pointer(int32)] "storage_scope" = "local";
    allocate(compute, int32, [(((floordiv(((threadIdx.z: int32*2) + 1), 4)*32) + 32) - (floordiv(threadIdx.z, 2)*32))]);
    attr [pad_data.shared: Pointer(int8)] "storage_scope" = "shared";
    allocate(pad_data.shared, int8x4, [56]);
    attr [packed_kernel.shared: Pointer(int8)] "storage_scope" = "shared";
    allocate(packed_kernel.shared, int8x4, [28]);
    attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 9;
    attr [IterVar(blockIdx.x_2: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 112;
    attr [IterVar(threadIdx.z, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 8;
    attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 7;
    attr [IterVar(threadIdx.x_2: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1 {
      for (oc_chunk.init: int32, 0, ((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2))) {
        for (zz.init: int32, 0, 2) "unroll" {
          for (yy.init: int32, 0, 2) "unroll" {
            for (oc_block.init: int32, 0, 4) "unroll" {
              compute[((((oc_chunk.init*16) + (zz.init*8)) + (yy.init*4)) + oc_block.init)] = 0
              compute[(((((((floordiv(((threadIdx.z*2) + 1), 4)*16) + (oc_chunk.init*16)) + (zz.init*8)) + (yy.init*4)) + oc_block.init) + 16) - (floordiv(threadIdx.z, 2)*16))] = 0
            }
          }
        }
      }
      for (rz.outer: int32, 0, 3) {
        for (ry.outer: int32, 0, 3) {
          for (ic_chunk.outer: int32, 0, 32) {
            for (rx.outer: int32, 0, 3) {
              attr [IterVar(threadIdx.z_1: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 8;
              attr [IterVar(threadIdx.y_1: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 7;
              attr [IterVar(threadIdx.x_3: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1;
              pad_data.shared[ramp(((threadIdx.z_1*28) + (threadIdx.y_1*4)), 1, 4)] = @tir.if_then_else(((((((1 <= (((blockIdx.y*2) + floordiv(((threadIdx.z_1*7) + threadIdx.y_1), 28)) + rz.outer)) && ((((blockIdx.y*2) + floordiv(((threadIdx.z_1*7) + threadIdx.y_1), 28)) + rz.outer) < 19)) && (1 <= (((floordiv(blockIdx.x_2, 28)*14) + floordiv(floormod(((threadIdx.z_1*7) + threadIdx.y_1), 28), 2)) + ry.outer))) && ((((floordiv(blockIdx.x_2, 28)*14) + floordiv(floormod(((threadIdx.z_1*7) + threadIdx.y_1), 28), 2)) + ry.outer) < 57)) && (1 <= (((floormod(blockIdx.x_2, 28)*2) + rx.outer) + floormod(((threadIdx.z_1*7) + threadIdx.y_1), 2)))) && ((((floormod(blockIdx.x_2, 28)*2) + rx.outer) + floormod(((threadIdx.z_1*7) + threadIdx.y_1), 2)) < 57)), (int8x4*)packed_data[ramp((((((((((((ic_chunk.outer*225792) + (blockIdx.y*25088)) + (floordiv(((threadIdx.z_1*7) + threadIdx.y_1), 28)*12544)) + (rz.outer*12544)) + (floordiv(blockIdx.x_2, 28)*3136)) + (floordiv(floormod(((threadIdx.z_1*7) + threadIdx.y_1), 28), 2)*224)) + (ry.outer*224)) + (floormod(blockIdx.x_2, 28)*8)) + (rx.outer*4)) + (floormod(((threadIdx.z_1*7) + threadIdx.y_1), 2)*4)) - 12772), 1, 4)], broadcast(0i8, 4), dtype=int8x4)
              attr [IterVar(threadIdx.z_2: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 8;
              attr [IterVar(threadIdx.y_2: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 7;
              attr [IterVar(threadIdx.x_4: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1;
              if (((threadIdx.z_2*7) + threadIdx.y_2) < 28) {
                if (threadIdx.z_2 < 4) {
                  if (((blockIdx.z*4) + floordiv(((threadIdx.z_2*7) + threadIdx.y_2), 4)) < 32) {
                    packed_kernel.shared[ramp(((threadIdx.z_2*28) + (threadIdx.y_2*4)), 1, 4)] = (int8x4*)packed_kernel[ramp((((((((blockIdx.z*55296) + (floordiv(((threadIdx.z_2*7) + threadIdx.y_2), 4)*13824)) + (ic_chunk.outer*432)) + (rz.outer*144)) + (ry.outer*48)) + (rx.outer*16)) + (floormod(((threadIdx.z_2*7) + threadIdx.y_2), 4)*4)), 1, 4)]
                  }
                }
              }
              for (oc_chunk: int32, 0, min((29 - (blockIdx.z*4)), ((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2)))) {
                for (zz: int32, 0, 2) "unroll" {
                  for (yy: int32, 0, 2) "unroll" {
                    for (oc_block: int32, 0, 4) "unroll" {
                      compute[((((oc_chunk*16) + (zz*8)) + (yy*4)) + oc_block)] = @tir.call_pure_extern("__dp4a", (int8x4*)pad_data.shared[ramp((((zz*112) + (threadIdx.y*16)) + (yy*8)), 1, 4)], (int8x4*)packed_kernel.shared[ramp((((floordiv(threadIdx.z, 2)*16) + (oc_chunk*16)) + (oc_block*4)), 1, 4)], (int32*)compute[((((oc_chunk*16) + (zz*8)) + (yy*4)) + oc_block)], dtype=int32)
                      compute[(((((((floordiv(((threadIdx.z*2) + 1), 4)*16) + (oc_chunk*16)) + (zz*8)) + (yy*4)) + oc_block) + 16) - (floordiv(threadIdx.z, 2)*16))] = @tir.call_pure_extern("__dp4a", (int8x4*)pad_data.shared[ramp(((((zz*112) + (threadIdx.y*16)) + (yy*8)) + 4), 1, 4)], (int8x4*)packed_kernel.shared[ramp((((floordiv(threadIdx.z, 2)*16) + (oc_chunk*16)) + (oc_block*4)), 1, 4)], (int32*)compute[(((((((floordiv(((threadIdx.z*2) + 1), 4)*16) + (oc_chunk*16)) + (zz*8)) + (yy*4)) + oc_block) + 16) - (floordiv(threadIdx.z, 2)*16))], dtype=int32)
                    }
                  }
                }
              }
              for (oc_chunk_1: int32, 0, (max(((((blockIdx.z*4) + floordiv(((threadIdx.z*2) + 1), 4)) - floordiv(threadIdx.z, 2)) - 29), -1) + 1)) {
                for (zz_1: int32, 0, 2) "unroll" {
                  for (yy_1: int32, 0, 2) "unroll" {
                    for (oc_block_1: int32, 0, 4) "unroll" {
                      if (((((blockIdx.z*4) + floordiv(threadIdx.z, 2)) + min((29 - (blockIdx.z*4)), ((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2)))) + oc_chunk_1) < 32) {
                        compute[(((((min((29 - (blockIdx.z*4)), ((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2)))*16) + (oc_chunk_1*16)) + (zz_1*8)) + (yy_1*4)) + oc_block_1)] = @tir.call_pure_extern("__dp4a", (int8x4*)pad_data.shared[ramp((((zz_1*112) + (threadIdx.y*16)) + (yy_1*8)), 1, 4)], (int8x4*)packed_kernel.shared[ramp(((((floordiv(threadIdx.z, 2)*16) + (min((29 - (blockIdx.z*4)), ((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2)))*16)) + (oc_chunk_1*16)) + (oc_block_1*4)), 1, 4)], (int32*)compute[(((((min((29 - (blockIdx.z*4)), ((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2)))*16) + (oc_chunk_1*16)) + (zz_1*8)) + (yy_1*4)) + oc_block_1)], dtype=int32)
                        compute[((((((((floordiv(((threadIdx.z*2) + 1), 4)*16) + (min((29 - (blockIdx.z*4)), ((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2)))*16)) + (oc_chunk_1*16)) + (zz_1*8)) + (yy_1*4)) + oc_block_1) + 16) - (floordiv(threadIdx.z, 2)*16))] = @tir.call_pure_extern("__dp4a", (int8x4*)pad_data.shared[ramp(((((zz_1*112) + (threadIdx.y*16)) + (yy_1*8)) + 4), 1, 4)], (int8x4*)packed_kernel.shared[ramp(((((floordiv(threadIdx.z, 2)*16) + (min((29 - (blockIdx.z*4)), ((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2)))*16)) + (oc_chunk_1*16)) + (oc_block_1*4)), 1, 4)], (int32*)compute[((((((((floordiv(((threadIdx.z*2) + 1), 4)*16) + (min((29 - (blockIdx.z*4)), ((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2)))*16)) + (oc_chunk_1*16)) + (zz_1*8)) + (yy_1*4)) + oc_block_1) + 16) - (floordiv(threadIdx.z, 2)*16))], dtype=int32)
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
      for (c.inner.inner.inner: int32, 0, 2) "unroll" {
        for (z.inner.inner.inner: int32, 0, 2) "unroll" {
          for (h.inner.inner.inner: int32, 0, 2) "unroll" {
            output_unpack_2[(((((((((blockIdx.z*903168) + (threadIdx.z*112896)) + (c.inner.inner.inner*56448)) + (blockIdx.y*6272)) + (z.inner.inner.inner*3136)) + (floordiv(blockIdx.x_2, 28)*784)) + (threadIdx.y*112)) + (h.inner.inner.inner*56)) + (floormod(blockIdx.x_2, 28)*2))] = (int32*)compute[(((((floordiv(((threadIdx.z*2) + c.inner.inner.inner), 4)*16) + (z.inner.inner.inner*8)) + (h.inner.inner.inner*4)) + floormod(((threadIdx.z*2) + c.inner.inner.inner), 4)) - (floordiv(threadIdx.z, 2)*16))]
            output_unpack_2[((((((((((blockIdx.z*903168) + (threadIdx.z*112896)) + (c.inner.inner.inner*56448)) + (blockIdx.y*6272)) + (z.inner.inner.inner*3136)) + (floordiv(blockIdx.x_2, 28)*784)) + (threadIdx.y*112)) + (h.inner.inner.inner*56)) + (floormod(blockIdx.x_2, 28)*2)) + 1)] = (int32*)compute[(((((((floordiv(((threadIdx.z*2) + c.inner.inner.inner), 4)*16) + (floordiv(((threadIdx.z*2) + 1), 4)*16)) + (z.inner.inner.inner*8)) + (h.inner.inner.inner*4)) + floormod(((threadIdx.z*2) + c.inner.inner.inner), 4)) + 16) - (floordiv(threadIdx.z, 2)*32))]
          }
        }
      }
    }
  }
}

#[metadata]
{
  "root": 1, 
  "nodes": [
    {
      "type_key": ""
    }, 
    {
      "type_key": "Map", 
      "keys": [
        "IntImm"
      ], 
      "data": [2]
    }, 
    {
      "type_key": "Array", 
      "data": [3]
    }, 
    {
      "type_key": "IntImm", 
      "attrs": {
        "dtype": "bool", 
        "value": "1"
      }
    }
  ], 
  "b64ndarrays": [], 
  "attrs": {"tvm_version": "0.8.dev0"}
}

I also wrote a minimal example to reproduce the problem.

"""Test for NCHW[x]c convolution"""

import os

import numpy as np

import tvm
from tvm import te
from tvm import autotvm
from tvm import topi
import tvm.testing
import tvm.topi.testing
from tvm.contrib.pickle_memoize import memoize
from tvm.topi.cuda.injective import schedule_injective_from_existing
from tvm.topi.cuda.tensor_intrin import dp4a
from tvm.topi.nn.pad import pad
from tvm.topi.nn.util import get_pad_tuple, get_pad_tuple3d
from tvm.topi.util import simplify, get_const_tuple, traverse_inline, tag

##########################################################################
#################### Operator and scheduler definition ###################
##########################################################################


def unpack_NCDHWc_to_ncdhw(packed_out, out_dtype):
    """Unpack conv3d_NCDHWc output from layout NCDHWc to NCDHW

    Parameters
    ----------
    packed_out : tvm.te.Tensor
        The output tensor of conv2d_NCHWc.

    out_dtype : str
        The output dtype.

    Returns
    -------
    unpacked_out : tvm.te.Tensor
        The unpacked output tensor in NCHW layout.
    """

    n, oc_chunk, oz, oh, ow, oc_bn = get_const_tuple(packed_out.shape)

    idxmod = tvm.tir.indexmod
    idxdiv = tvm.tir.indexdiv

    oshape = (n, oc_chunk * oc_bn, oz, oh, ow)
    unpacked_out = te.compute(
        oshape,
        lambda n, c, z, h, w: packed_out[n, idxdiv(c, oc_bn), z, h, w, idxmod(c, oc_bn)].astype(
            out_dtype
        ),
        name="output_unpack",
        tag=tag.INJECTIVE + ",unpack_ncdhwc",
    )
    return unpacked_out


def conv3d_ncdhw_int8(data, kernel, strides, padding, dilation, out_dtype="int32"):
    """Compute conv3d internally using conv3d_ncdhwc layout for int8 dtype"""
    assert data.dtype in ("int8", "uint8")
    assert kernel.dtype in ("int8", "uint8")
    assert data.dtype == kernel.dtype
    packed_out = conv3d_NCDHWc_int8(data, kernel, strides, padding, dilation, "NCDHW", out_dtype)
    return unpack_NCDHWc_to_ncdhw(packed_out, out_dtype)


def schedule_conv3d_ncdhw_int8(outs):
    """Create schedule for tensors"""
    return schedule_conv3d_NCDHWc_int8(outs)


def conv3d_NCDHWc_int8(data, kernel, stride, padding, dilation, layout, out_dtype):
    """Convolution operator in NCDHW[x]c layout for int8."""
    cfg = autotvm.get_config()

    assert layout in ["NCDHW", "NCDHW4c"]

    ic_block_factor = 4
    oc_block_factor = 4

    pre_computed = len(kernel.shape) == 7
    if not pre_computed:
        batch, channels, depth, height, width = get_const_tuple(data.shape)
        assert (
            channels % ic_block_factor == 0
        ), "Number of input channels should be multiple of {}".format(ic_block_factor)
        packed_data = te.compute(
            (batch, channels // ic_block_factor, depth, height, width, ic_block_factor),
            lambda n, c, d, h, w, vc: data[n, c * ic_block_factor + vc, d, h, w],
            name="packed_data",
        )

        out_channels, in_channels, kernel_d, kernel_h, kernel_w = get_const_tuple(kernel.shape)
        assert (
            out_channels % oc_block_factor == 0
        ), "Number of output channels should be multiple of {}".format(oc_block_factor)
        packed_kernel = te.compute(
            (
                out_channels // oc_block_factor,
                in_channels // ic_block_factor,
                kernel_d,
                kernel_h,
                kernel_w,
                oc_block_factor,
                ic_block_factor,
            ),
            lambda oc_chunk, ic_chunk, kd, kh, kw, oc_block, ic_block: kernel[
                oc_chunk * oc_block_factor + oc_block,
                ic_chunk * ic_block_factor + ic_block,
                kd,
                kh,
                kw,
            ],
            name="packed_kernel",
        )

    else:
        packed_data = data
        packed_kernel = kernel

    batch, ic_chunk, in_depth, in_height, in_width, ic_block = get_const_tuple(packed_data.shape)
    oc_chunk, ic_chunk, kernel_d, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
        packed_kernel.shape
    )
    assert isinstance(stride, int) or len(stride) == 3
    assert isinstance(dilation, int) or len(dilation) == 3

    if isinstance(stride, int):
        stride_d = stride_h = stride_w = stride
    else:
        stride_d, stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_d = dilation_h = dilation_w = dilation
    else:
        dilation_d, dilation_h, dilation_w = dilation

    # # compute the output shape

    pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
        padding, (kernel_d, kernel_h, kernel_w)
    )
    # out_channel = num_filter
    out_depth = (in_depth - kernel_d + pad_front + pad_back) // stride_d + 1
    out_height = (in_height - kernel_h + pad_top + pad_down) // stride_h + 1
    out_width = (in_width - kernel_w + pad_left + pad_right) // stride_w + 1

    oshape = (batch, oc_chunk, out_depth, out_height, out_width, oc_block)
    # compute graph
    pad_before = [0, 0, pad_front, pad_top, pad_left, 0]
    pad_after = [0, 0, pad_back, pad_down, pad_right, 0]
    pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")

    icc = te.reduce_axis((0, ic_chunk), name="ic_chunk")
    icb = te.reduce_axis((0, ic_block), name="ic_block")
    rz = te.reduce_axis((0, kernel_d), name="rz")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    conv = te.compute(
        oshape,
        lambda nn, oc_chunk, zz, yy, xx, oc_block: te.sum(
            pad_data[
                nn,
                icc,
                zz * stride_d + rz * dilation_d,
                yy * stride_h + ry * dilation_h,
                xx * stride_w + rx * dilation_w,
                icb,
            ].astype("int32")
            * packed_kernel[oc_chunk, icc, rz, ry, rx, oc_block, icb].astype("int32"),
            axis=[icc, rz, ry, rx, icb],
        ),
    )

    output = te.compute(
        oshape,
        lambda nn, oc_chunk, zz, yy, xx, oc_block: conv[nn, oc_chunk, zz, yy, xx, oc_block].astype(
            out_dtype
        ),
        tag="conv3d_NCDHWc_int8",
    )

    # num flop
    num_flop = (
        batch
        * oc_chunk
        * oc_block
        * out_depth
        * out_height
        * out_width
        * ic_chunk
        * ic_block
        * kernel_d
        * kernel_h
        * kernel_w
        * 2
    )
    cfg.add_flop(num_flop)

    return output


_dp4a = dp4a("shared", "shared", "local")
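# (the dp4a intrinsic computes a 4-element int8 dot product accumulated into int32;
# its arguments are the storage scopes of the two inputs and of the accumulator)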


def schedule_conv3d_NCDHWc_int8(outs):
    """Schedule conv3d int8 NCDHWc template"""
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])

    def _callback(op):
        if op.tag == "conv3d_NCDHWc_int8":
            _schedule_conv3d_NCDHWc_int8(s, op.output(0), "NCDHW", "conv3d_NCDHWc_int8.cuda")

    traverse_inline(s, outs[0].op, _callback)
    return s


def _schedule_conv3d_NCDHWc_int8(s, output, layout, workload_name):

    cfg = autotvm.get_config()

    conv = output.op.input_tensors[0]
    packed_data, packed_kernel = conv.op.input_tensors

    if isinstance(packed_data.op, tvm.te.ComputeOp) and "pad" in packed_data.op.tag:
        pad_data = packed_data
        packed_data = pad_data.op.input_tensors[0]
    else:
        pad_data = packed_data

    if autotvm.GLOBAL_SCOPE.in_tuning:
        # skip this part during tuning to make records accurate
        # this part will be pre-computed during NNVM's pre-compute optimization pass
        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
        s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
    else:
        if isinstance(packed_kernel.op, tvm.te.ComputeOp) and packed_kernel.name == "packed_kernel":
            # data and kernel are not pre-computed, schedule layout transform here
            schedule_injective_from_existing(s, packed_data)
            schedule_injective_from_existing(s, packed_kernel)
    if pad_data != packed_data:
        s[pad_data].compute_inline()

    AA = s.cache_read(pad_data, "shared", [conv])
    WW = s.cache_read(packed_kernel, "shared", [conv])

    s[conv].set_scope("local")
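    # pad_data and packed_kernel are staged through shared memory (AA / WW),
    # while the int32 accumulator of the convolution lives in registers ("local")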

    # handle bias
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    # tile and bind spatial axes
    if len(s[output].op.axis) == 6:
        n, f, d, y, x, c = s[output].op.axis
    else:
        # For auto-tuning task extraction, the expected output is the 5D (NCDHW) tensor.
        # Since auto-tuning tasks are created from scratch, the real auto-tuning will
        # still happen on the 6D (NCDHWc) output.
        n, f, d, y, x = s[output].op.axis

    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
    cfg.define_split("tile_d", cfg.axis(d), num_outputs=4)
    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)

    kernel_scope, n = s[output].split(n, nparts=1)
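    # kernel_scope is a length-1 outer loop used only as an anchor for the unroll pragmas below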

    # bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    bd, vd, td, di = cfg["tile_d"].apply(s, output, d)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(bf, bd, by, bx, vf, vd, vy, vx, tf, td, ty, tx, fi, di, yi, xi)

    bf = s[output].fuse(n, bf)

    s[output].bind(bf, te.thread_axis("blockIdx.z"))
    s[output].bind(bd, te.thread_axis("blockIdx.y"))
    s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
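    # batch/output-channel tiles -> blockIdx.z, depth tiles -> blockIdx.y,
    # fused height/width tiles -> blockIdx.x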
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vd, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))

    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
    if cfg["fuse_yx"].val:
        s[output].bind(tf, te.thread_axis("threadIdx.z"))
        s[output].bind(td, te.thread_axis("threadIdx.y"))
        tyx = s[output].fuse(ty, tx)
        s[output].bind(tyx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tyx)

        # number of threads
        n_tz = cfg["tile_f"].size[2]
        n_ty = cfg["tile_d"].size[2]
        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
    else:
        s[output].bind(s[output].fuse(tf, td), te.thread_axis("threadIdx.z"))
        s[output].bind(ty, te.thread_axis("threadIdx.y"))
        s[output].bind(tx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tx)

        # number of threads
        n_tz = cfg["tile_d"].size[2] * cfg["tile_f"].size[2]
        n_ty = cfg["tile_y"].size[2]
        n_tx = cfg["tile_x"].size[2]

    # tile reduction axes
    n, f, d, y, x, c = s[conv].op.axis
    rc, rd, ry, rx, rc_block = s[conv].op.reduce_axis

    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
    cfg.define_split("tile_rd", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
    rco, rci = cfg["tile_rc"].apply(s, conv, rc)
    rdo, rdi = cfg["tile_rd"].apply(s, conv, rd)
    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
    s[conv].reorder(rco, rdo, ryo, rxo, rci, rdi, ryi, rxi, n, f, d, y, x, c, rc_block)

    cfg.define_reorder("reorder_inner", [rco, rdo, ryo, rxo], policy="all")
    cfg["reorder_inner"].apply(s, conv, [rco, rdo, ryo, rxo])
    cfg["reorder_inner"].apply(s, conv, [rci, rdi, ryi, rxi])

    _, rc_block = s[conv].split(rc_block, factor=4)
    s[conv].tensorize(rc_block, _dp4a)
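    # the innermost group of 4 int8 multiply-accumulates is replaced by a single dp4a instruction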

    cache_loc = [rco, rdo, ryo, rxo][cfg["reorder_inner"].perm[-1]]
    s[AA].compute_at(s[conv], cache_loc)
    s[WW].compute_at(s[conv], cache_loc)

    # cooperative fetching
    for load in [AA, WW]:

        c = s[load].op.axis[-1]
        c_outer, c = s[load].split(c, factor=4)
        s[load].vectorize(c)
        fused = s[load].op.axis[:-1] + [c_outer]
        fused = s[load].fuse(*fused)
        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        fused, tz = s[load].split(fused, factor=n_tz)
        s[load].bind(tz, te.thread_axis("threadIdx.z"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))
        s[load].bind(tx, te.thread_axis("threadIdx.x"))

    # unroll
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", False)

    return s


##########################################################################
############################## Testing part ##############################
##########################################################################


@autotvm.template("tutorial/conv3d_int8")
def topi_conv(
    batch,
    in_channel,
    in_size,
    time_dim,
    num_filter,
    kernel,
    stride,
    padding,
    dilation=1,
    add_bias=False,
    add_relu=False,
    dtype="float32",
):

    A = te.placeholder(
        (batch, in_channel, time_dim, in_size, in_size),
        name="A",
        dtype="int8",
    )
    W = te.placeholder(
        (
            num_filter,
            in_channel,
            kernel,
            kernel,
            kernel,
        ),
        name="W",
        dtype="int8",
    )
    out = conv3d_ncdhw_int8(
        A,
        W,
        (stride, stride, stride),
        (padding, padding, padding),
        (dilation, dilation, dilation),
    )
    s = schedule_conv3d_NCDHWc_int8([out])

    # you can uncomment this line to see the generated code
    # print(tvm.lower(s, [A, W, out]))

    return s, [A, W, out]


(batch, in_channel, in_size, time_dim, num_filter, kernel, stride, padding, dilation) = (
    1,
    128,
    56,
    18,
    128,
    3,
    1,
    1,
    1,
)
A = te.placeholder((batch, in_channel, time_dim, in_size, in_size), name="A")

W = te.placeholder(
    (
        num_filter,
        in_channel,
        kernel,
        kernel,
        kernel,
    ),
    name="W",
)
target = "cuda"


task = autotvm.task.create(
    "tutorial/conv3d_int8",
    (batch, in_channel, in_size, time_dim, num_filter, kernel, stride, padding, dilation),
    target=target,
)

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=10, timeout=50),
)
tuner = autotvm.tuner.XGBTuner(task)

# uncomment if you want to tune the 3d convolution

# tuner.tune(
#     n_trial=10,
#     measure_option=measure_option,
#     callbacks=[
#         autotvm.callback.progress_bar(10, prefix="convolution"),
#         autotvm.callback.log_to_file("convolution.log"),
#     ],
# )

# An example of a configuration that does not work:
# {"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", "tutorial/conv3d_int8", [1, 128, 56, 18, 128, 3, 1, 1, 1], {}], "config": {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 2]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]], ["tile_x", "sp", [-1, 2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1, 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp", [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot", 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908], "version": 0.2, "tvm_version": "0.8.dev1"}

with autotvm.apply_history_best("convolution.log"):
    with tvm.target.Target(target):
        s, arg_bufs = topi_conv(
            batch, in_channel, in_size, time_dim, num_filter, kernel, stride, padding, dilation
        )
        func = tvm.build(s, arg_bufs, target=target)

The TVM tuner will explore several configurations and pick the best one.

For reproducibility purposes, here is an example of a configuration that makes TVM crash:

{"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", "tutorial/conv3d_int8", [1, 128, 56, 18, 128, 3, 1, 1, 1], {}], "config": {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 2]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]], ["tile_x", "sp", [-1, 2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1, 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp", [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot", 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908], "version": 0.2, "tvm_version": "0.8.dev1"}

You can write this configuration manually into the log file, and you will get the following result:

[12:36:52] /usr/tvm/src/tir/transforms/loop_partition.cc:548: Cannot prove: ((((((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2)) - 1) - (29 - (blockIdx.z*4))) + 1) >= 0), when generating the post doubt loop
Traceback (most recent call last):
  File "test_3dconv_optimization.py", line 491, in <module>
    func = tvm.build(s, arg_bufs, target=target)
  File "/usr/tvm/python/tvm/driver/build_module.py", line 414, in build
    mod_host, mdev = _build_for_device(input_mod, tar, target_host)
  File "/usr/tvm/python/tvm/driver/build_module.py", line 256, in _build_for_device
    mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)
  File "/usr/tvm/python/tvm/ir/transform.py", line 127, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 322, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (6) /usr/tvm/build/libtvm.so(TVMFuncCall+0x61) [0x7f3e6c5faf51]
  [bt] (5) /usr/tvm/build/libtvm.so(+0x644907) [0x7f3e6ba2f907]
  [bt] (4) /usr/tvm/build/libtvm.so(tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x40e) [0x7f3e6ba2ed1e]
  [bt] (3) /usr/tvm/build/libtvm.so(tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x1e2) [0x7f3e6ba2ce52]
  [bt] (2) /usr/tvm/build/libtvm.so(+0x8d347c) [0x7f3e6bcbe47c]
  [bt] (1) /usr/tvm/build/libtvm.so(tvm::tir::MakePackedAPI(tvm::tir::PrimFunc&&, int)+0x2d19) [0x7f3e6bcbb7a9]
  [bt] (0) /usr/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x61) [0x7f3e6b925f91]
  File "/usr/tvm/src/tir/transforms/make_packed_api.cc", line 210
TVMError: Not all Vars are passed in api_args:  'threadIdx.z'  is not bound to any variables

As mentioned previously, you can make this configuration valid by changing the value of tile_f (here, [-1, 1, 8, 2] becomes [-1, 1, 8, 1]):

{"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", "tutorial/conv3d_int8", [1, 128, 56, 18, 128, 3, 1, 1, 1], {}], "config": {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 1]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]], ["tile_x", "sp", [-1, 2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1, 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp", [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot", 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908], "version": 0.2, "tvm_version": "0.8.dev1"}

I believe this line is the issue, as it occurs before threadIdx.z is defined.

However, I cannot reproduce this issue with the script you’ve given me. Are you up to date on master? I used c7c39a492a51736dbd780fa12021fee6097bfe88.

Thank you for your quick reply.

I cloned the latest version (55d81720f3d05bce559d8b4d7972f54b0fa3eb60). I slightly modified the script because some files were renamed (util => utils).

"""Test for NCHW[x]c convolution"""

import numpy as np
import tvm
from tvm import te
from tvm import autotvm
from tvm import topi
import tvm.testing
import tvm.topi.testing
from tvm.contrib.pickle_memoize import memoize
from tvm.topi.nn.utils import get_pad_tuple
from tvm.topi.utils import get_const_tuple
import os


from tvm.topi.cuda.injective import schedule_injective_from_existing
from tvm.topi.cuda.tensor_intrin import dp4a
from tvm.topi.nn.pad import pad
from tvm.topi.nn.utils import get_pad_tuple3d
from tvm.topi.utils import simplify, get_const_tuple, traverse_inline, tag

##########################################################################
#################### Operator and scheduler definition ###################
##########################################################################


def unpack_NCDHWc_to_ncdhw(packed_out, out_dtype):
    """Unpack conv3d_NCDHWc output from layout NCDHWc to NCDHW

    Parameters
    ----------
    packed_out : tvm.te.Tensor
        The output tensor of conv3d_NCDHWc.

    out_dtype : str
        The output dtype.

    Returns
    -------
    unpacked_out : tvm.te.Tensor
        The unpacked output tensor in NCDHW layout.
    """

    n, oc_chunk, oz, oh, ow, oc_bn = get_const_tuple(packed_out.shape)

    idxmod = tvm.tir.indexmod
    idxdiv = tvm.tir.indexdiv

    oshape = (n, oc_chunk * oc_bn, oz, oh, ow)
    unpacked_out = te.compute(
        oshape,
        lambda n, c, z, h, w: packed_out[n, idxdiv(c, oc_bn), z, h, w, idxmod(c, oc_bn)].astype(
            out_dtype
        ),
        name="output_unpack",
        tag=tag.INJECTIVE + ",unpack_ncdhwc",
    )
    return unpacked_out


def conv3d_ncdhw_int8(data, kernel, strides, padding, dilation, out_dtype="int32"):
    """Compute conv3d internally using conv3d_ncdhwc layout for int8 dtype"""
    assert data.dtype in ("int8", "uint8")
    assert kernel.dtype in ("int8", "uint8")
    assert data.dtype == kernel.dtype
    packed_out = conv3d_NCDHWc_int8(data, kernel, strides, padding, dilation, "NCDHW", out_dtype)
    return unpack_NCDHWc_to_ncdhw(packed_out, out_dtype)


def schedule_conv3d_ncdhw_int8(outs):
    """Create schedule for tensors"""
    return schedule_conv3d_NCDHWc_int8(outs)


def conv3d_NCDHWc_int8(data, kernel, stride, padding, dilation, layout, out_dtype):
    """Convolution operator in NCDHW[x]c layout for int8."""
    cfg = autotvm.get_config()

    assert layout in ["NCDHW", "NCDHW4c"]

    ic_block_factor = 4
    oc_block_factor = 4

    pre_computed = len(kernel.shape) == 7
    if not pre_computed:
        batch, channels, depth, height, width = get_const_tuple(data.shape)
        assert (
            channels % ic_block_factor == 0
        ), "Number of input channels should be multiple of {}".format(ic_block_factor)
        packed_data = te.compute(
            (batch, channels // ic_block_factor, depth, height, width, ic_block_factor),
            lambda n, c, d, h, w, vc: data[n, c * ic_block_factor + vc, d, h, w],
            name="packed_data",
        )

        out_channels, in_channels, kernel_d, kernel_h, kernel_w = get_const_tuple(kernel.shape)
        assert (
            out_channels % oc_block_factor == 0
        ), "Number of output channels should be multiple of {}".format(oc_block_factor)
        packed_kernel = te.compute(
            (
                out_channels // oc_block_factor,
                in_channels // ic_block_factor,
                kernel_d,
                kernel_h,
                kernel_w,
                oc_block_factor,
                ic_block_factor,
            ),
            lambda oc_chunk, ic_chunk, kd, kh, kw, oc_block, ic_block: kernel[
                oc_chunk * oc_block_factor + oc_block,
                ic_chunk * ic_block_factor + ic_block,
                kd,
                kh,
                kw,
            ],
            name="packed_kernel",
        )

    else:
        packed_data = data
        packed_kernel = kernel

    batch, ic_chunk, in_depth, in_height, in_width, ic_block = get_const_tuple(packed_data.shape)
    oc_chunk, ic_chunk, kernel_d, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
        packed_kernel.shape
    )
    assert isinstance(stride, int) or len(stride) == 3
    assert isinstance(dilation, int) or len(dilation) == 3

    if isinstance(stride, int):
        stride_d = stride_h = stride_w = stride
    else:
        stride_d, stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_d = dilation_h = dilation_w = dilation
    else:
        dilation_d, dilation_h, dilation_w = dilation

    # compute the output shape

    pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
        padding, (kernel_d, kernel_h, kernel_w)
    )
    # out_channel = num_filter
    out_depth = (in_depth - kernel_d + pad_front + pad_back) // stride_d + 1
    out_height = (in_height - kernel_h + pad_top + pad_down) // stride_h + 1
    out_width = (in_width - kernel_w + pad_left + pad_right) // stride_w + 1

    oshape = (batch, oc_chunk, out_depth, out_height, out_width, oc_block)
    # compute graph
    pad_before = [0, 0, pad_front, pad_top, pad_left, 0]
    pad_after = [0, 0, pad_back, pad_down, pad_right, 0]
    pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")

    icc = te.reduce_axis((0, ic_chunk), name="ic_chunk")
    icb = te.reduce_axis((0, ic_block), name="ic_block")
    rz = te.reduce_axis((0, kernel_d), name="rz")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    conv = te.compute(
        oshape,
        lambda nn, oc_chunk, zz, yy, xx, oc_block: te.sum(
            pad_data[
                nn,
                icc,
                zz * stride_d + rz * dilation_d,
                yy * stride_h + ry * dilation_h,
                xx * stride_w + rx * dilation_w,
                icb,
            ].astype("int32")
            * packed_kernel[oc_chunk, icc, rz, ry, rx, oc_block, icb].astype("int32"),
            axis=[icc, rz, ry, rx, icb],
        ),
    )

    output = te.compute(
        oshape,
        lambda nn, oc_chunk, zz, yy, xx, oc_block: conv[nn, oc_chunk, zz, yy, xx, oc_block].astype(
            out_dtype
        ),
        tag="conv3d_NCDHWc_int8",
    )

    # num flop
    num_flop = (
        batch
        * oc_chunk
        * oc_block
        * out_depth
        * out_height
        * out_width
        * ic_chunk
        * ic_block
        * kernel_d
        * kernel_h
        * kernel_w
        * 2
    )
    cfg.add_flop(num_flop)

    return output


_dp4a = dp4a("shared", "shared", "local")


def schedule_conv3d_NCDHWc_int8(outs):
    """Schedule conv3d int8 NCDHWc template"""
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])

    def _callback(op):
        if op.tag == "conv3d_NCDHWc_int8":
            _schedule_conv3d_NCDHWc_int8(s, op.output(0), "NCDHW", "conv3d_NCDHWc_int8.cuda")

    traverse_inline(s, outs[0].op, _callback)
    return s


def _schedule_conv3d_NCDHWc_int8(s, output, layout, workload_name):

    cfg = autotvm.get_config()

    conv = output.op.input_tensors[0]
    packed_data, packed_kernel = conv.op.input_tensors

    if isinstance(packed_data.op, tvm.te.ComputeOp) and "pad" in packed_data.op.tag:
        pad_data = packed_data
        packed_data = pad_data.op.input_tensors[0]
    else:
        pad_data = packed_data

    if autotvm.GLOBAL_SCOPE.in_tuning:
        # skip this part during tuning to make records accurate
        # this part will be pre-computed during NNVM's pre-compute optimization pass
        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
        s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
    else:
        if isinstance(packed_kernel.op, tvm.te.ComputeOp) and packed_kernel.name == "packed_kernel":
            # data and kernel are not pre-computed, schedule layout transform here
            schedule_injective_from_existing(s, packed_data)
            schedule_injective_from_existing(s, packed_kernel)
    if pad_data != packed_data:
        s[pad_data].compute_inline()

    AA = s.cache_read(pad_data, "shared", [conv])
    WW = s.cache_read(packed_kernel, "shared", [conv])

    s[conv].set_scope("local")

    # handle bias
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    # tile and bind spatial axes
    if len(s[output].op.axis) == 6:
        n, f, d, y, x, c = s[output].op.axis
    else:
        # For auto-tuning task extraction, the expected output is the 5D (NCDHW) tensor.
        # Since auto-tuning tasks are created from scratch, the real auto-tuning will
        # still happen on the 6D (NCDHWc) output.
        n, f, d, y, x = s[output].op.axis

    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
    cfg.define_split("tile_d", cfg.axis(d), num_outputs=4)
    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)

    kernel_scope, n = s[output].split(n, nparts=1)

    # bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    bd, vd, td, di = cfg["tile_d"].apply(s, output, d)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(bf, bd, by, bx, vf, vd, vy, vx, tf, td, ty, tx, fi, di, yi, xi)

    bf = s[output].fuse(n, bf)

    s[output].bind(bf, te.thread_axis("blockIdx.z"))
    s[output].bind(bd, te.thread_axis("blockIdx.y"))
    s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vd, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))

    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
    if cfg["fuse_yx"].val:
        s[output].bind(tf, te.thread_axis("threadIdx.z"))
        s[output].bind(td, te.thread_axis("threadIdx.y"))
        tyx = s[output].fuse(ty, tx)
        s[output].bind(tyx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tyx)

        # number of threads
        n_tz = cfg["tile_f"].size[2]
        n_ty = cfg["tile_d"].size[2]
        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
    else:
        s[output].bind(s[output].fuse(tf, td), te.thread_axis("threadIdx.z"))
        s[output].bind(ty, te.thread_axis("threadIdx.y"))
        s[output].bind(tx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tx)

        # number of threads
        n_tz = cfg["tile_d"].size[2] * cfg["tile_f"].size[2]
        n_ty = cfg["tile_y"].size[2]
        n_tx = cfg["tile_x"].size[2]

    # tile reduction axes
    n, f, d, y, x, c = s[conv].op.axis
    rc, rd, ry, rx, rc_block = s[conv].op.reduce_axis

    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
    cfg.define_split("tile_rd", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
    rco, rci = cfg["tile_rc"].apply(s, conv, rc)
    rdo, rdi = cfg["tile_rd"].apply(s, conv, rd)
    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
    s[conv].reorder(rco, rdo, ryo, rxo, rci, rdi, ryi, rxi, n, f, d, y, x, c, rc_block)

    cfg.define_reorder("reorder_inner", [rco, rdo, ryo, rxo], policy="all")
    cfg["reorder_inner"].apply(s, conv, [rco, rdo, ryo, rxo])
    cfg["reorder_inner"].apply(s, conv, [rci, rdi, ryi, rxi])

    _, rc_block = s[conv].split(rc_block, factor=4)
    s[conv].tensorize(rc_block, _dp4a)

    cache_loc = [rco, rdo, ryo, rxo][cfg["reorder_inner"].perm[-1]]
    s[AA].compute_at(s[conv], cache_loc)
    s[WW].compute_at(s[conv], cache_loc)

    # cooperative fetching
    for load in [AA, WW]:

        c = s[load].op.axis[-1]
        c_outer, c = s[load].split(c, factor=4)
        s[load].vectorize(c)
        fused = s[load].op.axis[:-1] + [c_outer]
        fused = s[load].fuse(*fused)
        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        fused, tz = s[load].split(fused, factor=n_tz)
        s[load].bind(tz, te.thread_axis("threadIdx.z"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))
        s[load].bind(tx, te.thread_axis("threadIdx.x"))

    # unroll
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", False)

    return s


##########################################################################
############################## Testing part ##############################
##########################################################################


@autotvm.template("tutorial/conv3d_int8")
def topi_conv(
    batch,
    in_channel,
    in_size,
    time_dim,
    num_filter,
    kernel,
    stride,
    padding,
    dilation=1,
    add_bias=False,
    add_relu=False,
    dtype="float32",
):

    A = te.placeholder(
        (batch, in_channel, time_dim, in_size, in_size),
        name="A",
        dtype="int8",
    )
    W = te.placeholder(
        (
            num_filter,
            in_channel,
            kernel,
            kernel,
            kernel,
        ),
        name="W",
        dtype="int8",
    )
    out = conv3d_ncdhw_int8(
        A,
        W,
        (stride, stride, stride),
        (padding, padding, padding),
        (dilation, dilation, dilation),
    )
    s = schedule_conv3d_NCDHWc_int8([out])

    # you can uncomment this line to see the generated code
    # print(tvm.lower(s, [A, W, out]))

    return s, [A, W, out]


(batch, in_channel, in_size, time_dim, num_filter, kernel, stride, padding, dilation) = (
    1,
    128,
    56,
    18,
    128,
    3,
    1,
    1,
    1,
)
A = te.placeholder((batch, in_channel, time_dim, in_size, in_size), name="A")

W = te.placeholder(
    (
        num_filter,
        in_channel,
        kernel,
        kernel,
        kernel,
    ),
    name="W",
)
target = "cuda"


task = autotvm.task.create(
    "tutorial/conv3d_int8",
    (batch, in_channel, in_size, time_dim, num_filter, kernel, stride, padding, dilation),
    target=target,
)

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=10, timeout=50),
)
tuner = autotvm.tuner.XGBTuner(task)

# uncomment if you want to tune it

# tuner.tune(
#     n_trial=10,
#     measure_option=measure_option,
#     callbacks=[
#         autotvm.callback.progress_bar(10, prefix="convolution"),
#         autotvm.callback.log_to_file("convolution.log"),
#     ],
# )

# An example of a configuration that does not work:
# {"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", "tutorial/conv3d_int8", [1, 128, 56, 18, 128, 3, 1, 1, 1], {}], "config": {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 2]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]], ["tile_x", "sp", [-1, 2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1, 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp", [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot", 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908], "version": 0.2, "tvm_version": "0.8.dev0"}

with autotvm.apply_history_best("convolution_test.log"):
    with tvm.target.Target(target):
        s, arg_bufs = topi_conv(
            batch, in_channel, in_size, time_dim, num_filter, kernel, stride, padding, dilation
        )
        func = tvm.build(s, arg_bufs, target=target)

In addition, there was a typo in the TVM version (0.8.dev1 => 0.8.dev0). The faulty configuration mentioned previously becomes:

{"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", "tutorial/conv3d_int8", [1, 128, 56, 18, 128, 3, 1, 1, 1], {}], "config": {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 2]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]], ["tile_x", "sp", [-1, 2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1, 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp", [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot", 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908], "version": 0.2, "tvm_version": "0.8.dev0"}

I run my code in a Docker container created with the TVM Dockerfile, and I still get the error mentioned in my previous post.


I can reproduce it now. To me it looks like a bug in scheduling. Maybe @tqchen knows why this is happening?


I am less familiar with this part of the code; cc @vinx13, who might know a bit more.

Do you have any advice on how to solve this bug? In which file should I look?