Thread-bound loops go missing under a reduction block when lowering TIR

For example, this happens when we implement a reduction schedule that requires both a cross-thread (inter-thread) reduction and a shared-memory cache:

import tvm
from tvm.script import ir as I
from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer((128, 150528), "float32"), B: T.Buffer((128, 150528), "float32"), C: T.Buffer((128, 128), "float32")):
    # Reproducer for the LowerCrossThreadReduction issue discussed in this thread.
    # Computes C[i, j] = sum_k A[i, k] * B[j, k] (see block "B"), where:
    #   * the reduction axis k (150528 = 392 * 384) is split so the inner 384
    #     elements are 192 serial steps * 2 lanes bound to threadIdx.x, making
    #     this a cross-thread reduction over threadIdx.x;
    #   * for every ax2_0 chunk, tiles of A and B are staged into shared memory
    #     by cooperative copy loops that are themselves bound to threadIdx.y and
    #     threadIdx.x, nested inside the reduction's threadIdx.x loop.
    # Building for CUDA fails in MakePackedAPI because the copy loops' thread
    # bindings are dropped during lowering (see the error and IR dump below).
    T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
    # with T.block("root"):
    C_local = T.alloc_buffer((128, 128), scope="local")
    # Declared at full tensor shape; per thread block and per ax2_0 iteration
    # only an 8 x 384 tile of each buffer is written and read (see the copies).
    A_shared = T.alloc_buffer((128, 150528), scope="shared")
    B_shared = T.alloc_buffer((128, 150528), scope="shared")
    # 256 thread blocks = 16 x 16 output tiles of 8 x 8 elements of C.
    for ax0_0_ax1_0_fused in T.thread_binding(256, thread="blockIdx.x"):
        # Each of the 64 threadIdx.y values owns one output element of the tile:
        # row ax0_1_ax1_1_fused // 8, column ax0_1_ax1_1_fused % 8.
        for ax0_1_ax1_1_fused in T.thread_binding(64, thread="threadIdx.y"):
            # The 2 threadIdx.x lanes split the reduction axis between them.
            for ax2_1_1_fused in T.thread_binding(2, thread="threadIdx.x"):
                # Outer reduction loop: 392 chunks of 384 consecutive k values.
                for ax2_0 in range(392):
                    # Cooperative copy of an 8 x 384 tile of A into shared memory:
                    # 6 * 64 * 2 * 4 = 3072 = 8 * 384 elements. The threadIdx.y /
                    # threadIdx.x bindings below belong to this copy, not to the
                    # reduction, and must survive lowering; otherwise
                    # ax0_ax1_fused_1 / ax0_ax1_fused_2 end up unbound.
                    for ax0_ax1_fused_0 in range(6):
                        for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.serial(4):
                                    with T.block("A_shared"):
                                        v0 = T.axis.spatial(128, ax0_0_ax1_0_fused // 16 * 8 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 8 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 384)
                                        v1 = T.axis.spatial(150528, ax2_0 * 384 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 8 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 384)
                                        T.reads(A[v0, v1])
                                        T.writes(A_shared[v0, v1])
                                        A_shared[v0, v1] = A[v0, v1]
                    # Same cooperative copy for the matching 8 x 384 tile of B
                    # (rows selected by ax0_0_ax1_0_fused % 16 instead of // 16).
                    for ax0_ax1_fused_0 in range(6):
                        for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.serial(4):
                                    with T.block("B_shared"):
                                        v0 = T.axis.spatial(128, ax0_0_ax1_0_fused % 16 * 8 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 8 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 384)
                                        v1 = T.axis.spatial(150528, ax2_0 * 384 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 8 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 384)
                                        T.reads(B[v0, v1])
                                        T.writes(B_shared[v0, v1])
                                        B_shared[v0, v1] = B[v0, v1]
                    # Reduction over this chunk: each threadIdx.x lane handles
                    # k = ax2_0 * 384 + ax2_1_0 * 2 + lane, i.e. 192 of the 384 values.
                    for ax2_1_0 in range(192):
                        with T.block("B"):
                            v0 = T.axis.spatial(128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8)
                            v1 = T.axis.spatial(128, ax0_0_ax1_0_fused % 16 * 8 + ax0_1_ax1_1_fused % 8)
                            v2 = T.axis.reduce(150528, ax2_0 * 384 + ax2_1_0 * 2 + ax2_1_1_fused)
                            T.reads(A_shared[v0, v2], B_shared[v1, v2])
                            T.writes(C_local[v0, v1])
                            with T.init():
                                C_local[v0, v1] = T.float32(0)
                            C_local[v0, v1] = C_local[v0, v1] + A_shared[v0, v2] * B_shared[v1, v2]
            # Write the reduced value back from the local buffer to C; this block
            # sits outside the threadIdx.x loop.
            with T.block("C_local"):
                v0 = T.axis.spatial(128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8)
                v1 = T.axis.spatial(128, ax0_0_ax1_0_fused % 16 * 8 + ax0_1_ax1_1_fused % 8)
                T.reads(C_local[v0, v1])
                T.writes(C[v0, v1])
                C[v0, v1] = C_local[v0, v1]


# Wrap the PrimFunc in an IRModule and create a TIR schedule with all debug
# checks enabled (debug_mask="all"). No schedule primitives are applied here:
# the thread bindings are already written into `main` above.
mod = tvm.IRModule.from_expr(main)
sch = tvm.tir.Schedule(mod, debug_mask="all")
# Building for CUDA is the step that raises the MakePackedAPI error shown below.
rt_mod = tvm.build(sch.mod, target="cuda")

Building it fails with the following error:

Traceback (most recent call last):
  File "/home/t-leiwang/mlc_workspace/unity/python/tvm/dlight/test_tir.py", line 56, in <module>
    dense_relu_0_rt_mod = tvm.build(sch.mod, target="cuda")
  File "/home/t-leiwang/ladder_workspace/LadderTVM/python/tvm/driver/build_module.py", line 281, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "/home/t-leiwang/ladder_workspace/LadderTVM/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  12: TVMFuncCall
  11: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}>(tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  10: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  9: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
  8: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
  7: tvm::transform::Pass::operator()(tvm::IRModule) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_3tir9transform13MakePackedAPIEvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
  1: tvm::tir::transform::MakePackedAPI()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}::operator()(tvm::IRModule, tvm::transform::PassContext) const [clone .isra.0]
  0: tvm::tir::MakePackedAPI(tvm::tir::PrimFunc&&)
  File "/transforms/make_packed_api.cc", line 295
TVMError: Not all Vars are passed in api_args:  'ax0_ax1_fused_1'  'ax0_ax1_fused_2'  'ax0_ax1_fused_1'  'ax0_ax1_fused_2'  is not bound to any variables

However, the same schedule works correctly when written with Tensor Expression (TE).

This issue appears to be caused by the `LowerCrossThreadReduction` pass.

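The pass removes every thread-bound loop under the cross-thread reduction block. This includes the `threadIdx.y`/`threadIdx.x` loops of the shared-memory copy blocks, as the lowered IR below shows: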

PrimFunc([A_handle, B_handle, C_handle]) attrs={"global_symbol": "main", "tir.noalias": (bool)1} {
  block root() {
    reads([])
    writes([])
    C_local = alloc_buffer(float32[128, 128])
    A_shared = alloc_buffer(float32[128, 150528])
    B_shared = alloc_buffer(float32[128, 150528])
    cross_thread_C_local = alloc_buffer(float32[1])
    in_thread_C_local = alloc_buffer(float32[1])
    launch_thread (ax0_0_ax1_0_fused, 0, 256) {
      launch_thread (ax0_1_ax1_1_fused, 0, 64) {
        launch_thread (ax2_1_1_fused, 0, 2) {
          block B_in_thread_init() {
            reads([])
            writes([in_thread_C_local[0]])
            in_thread_C_local[0] = 0f
          }
          for (ax2_0, 0, 392) {
            for (ax0_ax1_fused_0, 0, 6) {
              for (ax0_ax1_fused_3, 0, 4) {
                block A_shared(iter_var(v0, range(min=0, ext=128)), iter_var(v1, range(min=0, ext=150528))) {
                  bind(v0, ((floordiv(ax0_0_ax1_0_fused, 16)*8) + floordiv(((((ax0_ax1_fused_0*512) + (ax0_ax1_fused_1*8)) + (ax0_ax1_fused_2*4)) + ax0_ax1_fused_3), 384)))
                  bind(v1, ((ax2_0*384) + floormod(((((ax0_ax1_fused_0*512) + (ax0_ax1_fused_1*8)) + (ax0_ax1_fused_2*4)) + ax0_ax1_fused_3), 384)))
                  reads([A[v0, v1]])
                  writes([A_shared[v0, v1]])
                  A_shared[v0, v1] = A[v0, v1]
                }
              }
            }
            for (ax0_ax1_fused_0, 0, 6) {
              for (ax0_ax1_fused_3, 0, 4) {
                block B_shared(iter_var(v0, range(min=0, ext=128)), iter_var(v1, range(min=0, ext=150528))) {
                  bind(v0, ((floormod(ax0_0_ax1_0_fused, 16)*8) + floordiv(((((ax0_ax1_fused_0*512) + (ax0_ax1_fused_1*8)) + (ax0_ax1_fused_2*4)) + ax0_ax1_fused_3), 384)))
                  bind(v1, ((ax2_0*384) + floormod(((((ax0_ax1_fused_0*512) + (ax0_ax1_fused_1*8)) + (ax0_ax1_fused_2*4)) + ax0_ax1_fused_3), 384)))
                  reads([B[v0, v1]])
                  writes([B_shared[v0, v1]])
                  B_shared[v0, v1] = B[v0, v1]
                }
              }
            }
            for (ax2_1_0, 0, 192) {
              block B_in_thread(iter_var(v0, range(min=0, ext=128)), iter_var(v1, range(min=0, ext=128)), iter_var(v2, range(min=0, ext=150528))) {
                bind(v0, ((floordiv(ax0_0_ax1_0_fused, 16)*8) + floordiv(ax0_1_ax1_1_fused, 8)))
                bind(v1, ((floormod(ax0_0_ax1_0_fused, 16)*8) + floormod(ax0_1_ax1_1_fused, 8)))
                bind(v2, (((ax2_0*384) + (ax2_1_0*2)) + ax2_1_1_fused))
                reads([A_shared[v0, v2], B_shared[v1, v2]])
                writes([in_thread_C_local[0]])
                in_thread_C_local[0] = (in_thread_C_local[0] + (A_shared[v0, v2]*B_shared[v1, v2]))
              }
            }
          }
          block B_cross_thread() {
            reads([in_thread_C_local[0]])
            writes([cross_thread_C_local[0]])
            // attr [comm_reducer(result=[(x0 + y0)], lhs=[x0], rhs=[y0], identity_element=[0f])] reduce_scope = tir.reinterpret((uint64)0)
            tir.tvm_thread_allreduce((uint32)1, in_thread_C_local[0], (bool)1, cross_thread_C_local[0], ax2_1_1_fused)
          }
          block B_write_back(iter_var(v0, range(min=0, ext=128)), iter_var(v1, range(min=0, ext=128))) {
            bind(v0, ((floordiv(ax0_0_ax1_0_fused, 16)*8) + floordiv(ax0_1_ax1_1_fused, 8)))
            bind(v1, ((floormod(ax0_0_ax1_0_fused, 16)*8) + floormod(ax0_1_ax1_1_fused, 8)))
            reads([cross_thread_C_local[0]])
            writes([C_local[v0, v1]])
            C_local[v0, v1] = cross_thread_C_local[0]
          }
        }
        block C_local(iter_var(v0, range(min=0, ext=128)), iter_var(v1, range(min=0, ext=128))) {
          bind(v0, ((floordiv(ax0_0_ax1_0_fused, 16)*8) + floordiv(ax0_1_ax1_1_fused, 8)))
          bind(v1, ((floormod(ax0_0_ax1_0_fused, 16)*8) + floormod(ax0_1_ax1_1_fused, 8)))
          reads([C_local[v0, v1]])
          writes([C[v0, v1]])
          C[v0, v1] = C_local[v0, v1]
        }
      }
    }
  }
}

Is there a specific reason to remove these loops under the reduction block?

cc @MasterJH5574

By the way, after commenting out the function in `LowerCrossThreadReduction` that removes these thread-bound loops, the build works as expected.

CC @MasterJH5574 in case you'd like to take a look.

The code @LeiWang1999 pointed out exists to generate the reduction block that is computed within a single thread (the in-thread part of the reduction):

            for (ax2_1_0, 0, 192) {
              block B_in_thread(iter_var(v0, range(min=0, ext=128)), iter_var(v1, range(min=0, ext=128)), iter_var(v2, range(min=0, ext=150528))) {
                bind(v0, ((floordiv(ax0_0_ax1_0_fused, 16)*8) + floordiv(ax0_1_ax1_1_fused, 8)))
                bind(v1, ((floormod(ax0_0_ax1_0_fused, 16)*8) + floormod(ax0_1_ax1_1_fused, 8)))
                bind(v2, (((ax2_0*384) + (ax2_1_0*2)) + ax2_1_1_fused))
                reads([A_shared[v0, v2], B_shared[v1, v2]])
                writes([in_thread_C_local[0]])
                in_thread_C_local[0] = (in_thread_C_local[0] + (A_shared[v0, v2]*B_shared[v1, v2]))
              }
            }

This mutator does not consider the case where other, non-reduction blocks (like the shared-memory copy blocks in your case) sit under the same loops. It therefore removes every thread-bound loop, which leads to the issue we see here.

Here is a quick fix: before removing a thread-bound loop, check whether the block(s) under it have any reduction block var. If they have none, those blocks are not reduction blocks, so the thread-bound loop should be kept. Otherwise, remove the thread-bound loop as usual.

1 Like