te.gradient does not work with a complex forward workload

Hey guys,
I am trying to use the auto-scheduler to generate CUDA source code for the backward stage of an NCHW winograd conv2d. Because of some bugs in topi.cuda.conv2d_winograd.winograd_cuda, I copied parts of its code to build my own workload.

Luckily, the workload works without te.gradient, and I can successfully get source code for the forward stage. But as soon as I add te.gradient, the workload no longer works and I get the error message below:

Check failed: (!repl_op.same_as(s->op)) is false: Cannot find Tensor(shape=[4, 2], op.name=A) in the inputs of compute(extracted_tensor.d.shared, ......
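
For reference, the way I call te.gradient follows the standard pattern. Here is a minimal sketch with a toy matmul, just to show the usage (the names X, W, Y are made up; this is not my real code):

# Minimal te.gradient usage sketch (toy matmul, for illustration only)
import tvm
from tvm import te

n = 128
X = te.placeholder((n, n), name="X")
W = te.placeholder((n, n), name="W")
k = te.reduce_axis((0, n), name="k")
Y = te.compute((n, n), lambda i, j: te.sum(X[i, k] * W[k, j], axis=k), name="Y")

# dy is the adjoint of Y; te.gradient returns dY/dX contracted with dy
dy = te.placeholder(Y.shape, name="dy")
[dX] = te.gradient(Y, [X], head=dy)

s = te.create_schedule(dX.op)
print(tvm.lower(s, [X, W, dy, dX], simple_mode=True))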

I am really confused now. The fact that the forward-stage codegen works suggests that my workload is correct, at least to some degree, so I suspect this may be caused by a bug in TVM, but I am not sure.

I wonder if anyone can help me find out what is wrong with my code. I also wonder whether there is a better way to generate backward CUDA source code for an NCHW-format winograd conv2d; one alternative I am considering is sketched right below.
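
The alternative (an untested sketch, and it gives up winograd for the backward pass) would be to register the backward workload on top of the plain NCHW conv2d from topi and let the auto-scheduler optimize that instead. The function name conv2d_nchw_backward_data below is mine:

# Untested sketch: backward-data workload built on the plain (non-winograd) NCHW conv2d
import tvm
from tvm import auto_scheduler, te, topi

@auto_scheduler.register_workload
def conv2d_nchw_backward_data(N, H, W, CO, CI, KH, KW, stride, padding):
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    out = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1)
    dy = te.placeholder(out.shape, name="dy")
    # gradient of the output w.r.t. the input data, contracted with the adjoint dy
    [d_data] = te.gradient(out, [data], head=dy)
    return [data, kernel, dy, d_data]

# e.g. task = auto_scheduler.SearchTask(
#          func=conv2d_nchw_backward_data,
#          args=(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)),
#          target=tvm.target.Target("cuda"),
#      )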

Thanks a lot!!!

My TVM version is 0.8.0. I built it from the source code on the "Download Apache TVM Source Code" web page.

My code is:


import os

import tvm
from tvm import auto_scheduler, te
from tvm.topi import nn
from tvm.topi.utils import get_const_tuple
from tvm.topi.nn.winograd_util import winograd_transform_matrices

def _infer_tile_size(data, kernel, layout="NCHW"):
    if layout == "NCHW":
        N, CI, H, W = get_const_tuple(data.shape)
    else:
        assert layout == "NHWC"
        N, H, W, CI = get_const_tuple(data.shape)

    if H % 8 == 0:
        return 4
    return 2

@auto_scheduler.register_workload
def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")

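    # NOTE: stride, padding and dilation are hard-coded below and shadow the function arguments.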
    stride = (1,1)
    padding = (1,1)
    dilation = (1,1)
    pre_computed = False
    out_dtype = "float32"

    tile_size = _infer_tile_size(data, kernel)
    N, CI, H, W = get_const_tuple(data.shape)

    if isinstance(N, tvm.tir.Any):
        N = tvm.te.size_var("n")

    if not isinstance(H, int) or not isinstance(W, int):
        raise RuntimeError(
            "cuda winograd conv2d doesn't support dynamic input height or width."
        )

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation
    HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride

    if not pre_computed:  # kernel tensor is raw tensor, do strict check
        if dilation_h != 1 or dilation_w != 1:
            kernel = nn.dilate(kernel, (1, 1, dilation_h, dilation_w))
        CO, CI, KH, KW = get_const_tuple(kernel.shape)
        alpha = KW + tile_size - 1
        assert HSTR == 1 and WSTR == 1 and KH == KW
    else:
        # kernel tensor is pre-transformed. This op is created by alter_op_layout.
        # dilation is not supported
        alpha, _, CI, CO = get_const_tuple(kernel.shape)
        KH = KW = alpha + 1 - tile_size
        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1

    pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
    data_pad = nn.pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")

    r = KW
    m = tile_size
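    # A, B, G are the small constant Winograd transform matrices for F(m, r):
    # G transforms the kernel, B transforms the input tiles, A applies the inverse transform.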
    A, B, G = winograd_transform_matrices(m, r, out_dtype)

    H = (H + pt + pb - KH) // HSTR + 1
    W = (W + pl + pr - KW) // WSTR + 1
    nH, nW = (H + m - 1) // m, (W + m - 1) // m

    P = N * nH * nW if isinstance(N, int) else nH * nW

    # transform kernel
    if not pre_computed:
        r_kh = te.reduce_axis((0, KH), name="r_kh")
        r_kw = te.reduce_axis((0, KW), name="r_kw")
        kernel_pack = te.compute(
            (alpha, alpha, CI, CO),
            lambda eps, nu, ci, co: te.sum(
                kernel[co][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
            ),
            name="my_kernel_pack",
        )
    else:
        kernel_pack = kernel

    idxdiv = tvm.tir.indexdiv
    idxmod = tvm.tir.indexmod
    # pack input tile
    input_tile = te.compute(
        (CI, P, alpha, alpha),
        lambda c, p, eps_1, nu_1: data_pad[idxdiv(p, (nH * nW))][c][
            idxmod(idxdiv(p, nW), nH) * m + eps_1
        ][idxmod(p, nW) * m + nu_1],
        name="my_d",
    )

    # dy = tvm.te.placeholder(input_tile.shape, name="input2_dy")
    # [dw] = tvm.te.gradient(input_tile, [data], head=dy)
    # return [data, kernel, input_tile, dy, dw]

    # transform data
    r_a = te.reduce_axis((0, alpha), "r_a")
    r_b = te.reduce_axis((0, alpha), "r_b")
    data_pack = te.compute(
        (alpha, alpha, CI, P),
        lambda eps, nu, ci, p: te.sum(
            input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
        ),
        name="my_data_pack",
    )

    # dy = tvm.te.placeholder(data_pack.shape, name="input2_dy")
    # [dw] = tvm.te.gradient(data_pack, [data], head=dy)
    # return [data, kernel, data_pack, dy, dw]

    # do batch gemm
    ci = te.reduce_axis((0, CI), name="ci")
    bgemm = te.compute(
        (alpha, alpha, CO, P),
        lambda eps, nu, co, p: te.sum(
            kernel_pack[eps][nu][ci][co] * data_pack[eps][nu][ci][p], axis=[ci]
        ),
        name="my_bgemm",
    )
    # inverse transform
    r_a_2 = te.reduce_axis((0, alpha), "r_a_2")
    r_b_2 = te.reduce_axis((0, alpha), "r_b_2")
    inverse = te.compute(
        (CO, P, m, m),
        lambda co, p, vh, vw: te.sum(
            bgemm[r_a_2][r_b_2][co][p] * A[r_a_2][vh] * A[r_b_2][vw], axis=[r_a_2, r_b_2]
        ),
        name="my_inverse",
    )

    # output
    output = te.compute(
        (N, CO, H, W),
        lambda n, co, h, w: inverse[
            co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), idxmod(h, m), idxmod(w, m)
        ],
        name="my_output",
        tag="conv2d_nchw_winograd",
    )
    
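    # Attach te.gradient to the final output: dy is the adjoint (same shape as the
    # output) and dw is the gradient of the output w.r.t. data, contracted with dy.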
    dy = tvm.te.placeholder(output.shape, name="input2_dy")
    [dw] = tvm.te.gradient(output, [data], head=dy)
    return [data, kernel, output, dy, dw]
    # return [data, kernel, output]

target = tvm.target.Target("cuda")

# Use the last layer in ResNet-50
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = auto_scheduler.SearchTask(
    func=conv2d_layer, args=(N, H, W, CO, CI, KH, KW, strides, padding), target=target
)

# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)

log_file = "conv2d.json"
if os.path.exists(log_file):
    os.remove(log_file)
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # change this to 1000 to achieve the best performance
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)

# Kill the measurement process
del measure_ctx

print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))

func = tvm.build(sch, args, target)

print("Equivalent python schedule:")
print(task.print_best(log_file, print_mode="schedule"))

print("CUDA source code:")
print(task.print_best(log_file, print_mode="cuda"))

Part of the log I got:

  1: tvm::auto_scheduler::CacheReadStepNode::ApplyToSchedule(tvm::runtime::Array<tvm::te::Stage, void>*, tvm::runtime::Map<tvm::te::Stage, tvm::runtime::Array<tvm::tir::IterVar, void>, tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual>*, tvm::te::Schedule*) const
  0: tvm::te::Schedule::cache_read(tvm::te::Tensor const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::Array<tvm::te::Operation, void> const&)
  File "/data/apache-tvm-src-v0.8.0.rc0/src/te/schedule/schedule_dataflow_rewrite.cc", line 168
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (!repl_op.same_as(s->op)) is false: Cannot find Tensor(shape=[4, 2], op.name=A) in the inputs of compute(extracted_tensor.d.shared, body=[extracted_tensor[ax0, ax1, ax2, ax3]], axis=[iter_var(ax0, range(min=0, ext=2)), iter_var(ax1, range(min=0, ext=2)), iter_var(ax2, range(min=0, ext=4)), iter_var(ax3, range(min=0, ext=4))], reduce_axis=[], tag=, attrs={})

I got the same problem with exactly the same TVM version. It seems to be caused by the cache_read schedule step.

Do you have any idea how to fix this problem? T_T
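
One check that might help narrow it down (untested on my side): build the gradient DAG with a plain TE schedule, which never inserts cache_read. If this lowers and builds, te.gradient itself is fine and the failure really is in the auto-scheduler's CacheRead step.

# Untested sketch: does the gradient DAG build without the auto-scheduler?
import tvm
from tvm import te

# conv2d_layer is the registered workload defined in the post above
data, kernel, output, dy, dw = conv2d_layer(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1))

s = te.create_schedule(dw.op)
print(tvm.lower(s, [data, kernel, dy, dw], simple_mode=True))
tvm.build(s, [data, kernel, dy, dw], target="llvm")  # plain CPU build, no cache_read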