Confused about cache related tir schedule primitives

Confused about tir cache related schedule primitives

Take matrix multiplication as an example:

@tvm.script.ir_module
class MyModule:
    # C[M, N] = A[M, K] @ B[K, N]; M, N and K must be defined before this class.
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [M, K])
        B = T.match_buffer(b, [K, N])
        C = T.match_buffer(c, [M, N])

        # Loop extents must line up with how the block vars index the buffers:
        # i -> vi over M (rows of C), j -> vj over N (cols of C), k -> vk over K
        # (the reduction axis shared by A and B).
        for i, j, k in T.grid(M, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

Here is a simple cache_write schedule:

# Problem size. With M == N == K the loop-extent order below is not
# observable, but it must still match the buffer shapes for general sizes.
M = N = K = 16384


@tvm.script.ir_module
class MyModule:
    # C[M, N] = A[M, K] @ B[K, N].
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [M, K])
        B = T.match_buffer(b, [K, N])
        C = T.match_buffer(c, [M, N])

        # i -> vi over M, j -> vj over N (spatial), k -> vk over K (reduction).
        for i, j, k in T.grid(M, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


ir_module = MyModule
sch = tvm.tir.Schedule(ir_module)

# Tile sizes: each (blockIdx.y, blockIdx.x) pair covers a
# block_h x block_w tile of C, one element per thread.
block_h = 32
block_w = 32

block_b = sch.get_block("B")
(i, j, k) = sch.get_loops(block_b)
# Split the two spatial loops into (block index, thread index) pairs.
by, yi = sch.split(i, factors=[None, block_h])
bx, xi = sch.split(j, factors=[None, block_w])
sch.reorder(by, bx, yi, xi)
sch.bind(by, "blockIdx.y")
sch.bind(bx, "blockIdx.x")
sch.bind(yi, "threadIdx.y")
sch.bind(xi, "threadIdx.x")
# cache_write makes block_b write C_local and appends a new "C_local" block
# that copies C_local into C; that new block writes the function output.
block_cl = sch.cache_write(block_b, 0, "local")
# NOTE: this call raises "The block ... is an output block" (see the
# traceback below). compute_at refuses output blocks; moving a consumer-side
# block like block_cl under yi is the job of reverse_compute_at. Alternatively,
# compute_at block_b under the write-back block's loops, as in the working
# rewrite further down.
sch.compute_at(block_cl, yi, preserve_unit_loops=True)

We also need compute_at, because otherwise “C” is not contained in a thread environment or in the function arguments. However, I got an error when executing the compute_at primitive:

Traceback (most recent call last):
  File "/workspace/v-leiwang3/tvm_gpu_gemm/tensorir_/1.blocked_gemm.py", line 77, in <module>
    sch.compute_at(block_cl, yi, preserve_unit_loops=True)
  File "/workspace/v-leiwang3/tvm/python/tvm/tir/schedule/_type_checker.py", line 274, in wrap
    return func(*args, **kwargs)
  File "/workspace/v-leiwang3/tvm/python/tvm/tir/schedule/schedule.py", line 1349, in compute_at
    preserve_unit_loops,
  File "/workspace/v-leiwang3/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
  4: TVMFuncCall
  3: _ZN3tvm7runtime13PackedFun
  2: tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::tir::Schedule, tvm::tir::ScheduleNode, void, tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, void>(void (tvm::tir::ScheduleNode::*)(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool))::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)#1}>(tvm::runtime::Registry::set_body_method<tvm::tir::Schedule, tvm::tir::ScheduleNode, void, tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, void>(void (tvm::tir::ScheduleNode::*)(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool))::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  1: tvm::tir::TracedScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)
  0: tvm::tir::ConcreteScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)
ScheduleError: An error occurred in the schedule primitive 'compute-at'.
The IR with diagnostic is:
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[(16384, 16384), "float32"], B: T.Buffer[(16384, 16384), "float32"], C: T.Buffer[(16384, 16384), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        C_local = T.alloc_buffer([16384, 16384], dtype="float32", scope="local")
        for i_0 in T.thread_binding(512, thread="blockIdx.y"):
            for j_0 in T.thread_binding(512, thread="blockIdx.x"):
                for i_1 in T.thread_binding(32, thread="threadIdx.y"):
                    for j_1 in T.thread_binding(32, thread="threadIdx.x"):
                        for k in T.serial(16384):
                            with T.block("B"):
                                vi = T.axis.spatial(16384, i_0 * 32 + i_1)
                                vj = T.axis.spatial(16384, j_0 * 32 + j_1)
                                vk = T.axis.reduce(16384, k)
                                T.reads(A[vi, vk], B[vk, vj])
                                T.writes(C_local[vi, vj])
                                with T.init():
                                    C_local[vi, vj] = T.float32(0)
                                C_local[vi, vj] = C_local[vi, vj] + A[vi, vk] * B[vk, vj]
        for ax0, ax1 in T.grid(16384, 16384):
            # tir.Block#0
            with T.block("C_local"):
            ^^^^^^^^^^^^^^^^^^^^^^^^
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(C_local[v0, v1])
                T.writes(C[v0, v1])
                C[v0, v1] = C_local[v0, v1]
    
Error message: The block tir.Block#0 is an output block

I rewrote the code according to my understanding and it works:

# Working variant: tile and bind the write-back block's loops, then move the
# compute block under them.
block_b = sch.get_block("B")
# block_cl is the new block that copies C_local into C.
block_cl = sch.cache_write(block_b, 0, "local")
# The write-back block only has the two spatial loops.
(i, j) = sch.get_loops(block_cl)
by, yi = sch.split(i, factors=[None, block_h])
bx, xi = sch.split(j, factors=[None, block_w])
sch.reorder(by, bx, yi, xi)
sch.bind(by, "blockIdx.y")
sch.bind(bx, "blockIdx.x")
sch.bind(yi, "threadIdx.y")
sch.bind(xi, "threadIdx.x")
# block_b writes only C_local, so it is not an output block and compute_at
# may move it under its consumer's innermost thread loop.
sch.compute_at(block_b, xi, preserve_unit_loops=True)

What confuses me is that the code that doesn’t work is adapted from my tensor expression experience:

    # TE reference schedule that the TIR schedule above was adapted from.
    # NOTE: _dtype and write_code are not defined in this excerpt; they come
    # from the surrounding script.
    nn = 16384
    # Use a constant extent for all three dimensions; m, n and l alias it.
    n = tvm.runtime.convert(nn)
    m, l = n, n
    A = te.placeholder((l, n), dtype=_dtype, name="A")
    B = te.placeholder((l, m), dtype=_dtype, name="B")
    k = te.reduce_axis((0, l), name="k")
    # C[ii, jj] = sum_k A[k, jj] * B[k, ii]
    C = te.compute((m, n), lambda ii, jj: te.sum(
        A[k, jj] * B[k, ii], axis=k), name="C")
    # schedule
    s = te.create_schedule(C.op)
    # CC is the new stage computing in "local" scope; C reads from it.
    CC = s.cache_write(C, "local")
    write_code(
        str(tvm.lower(s, [A, B, C], simple_mode=True)), "progress/1.cache.cu")

    block_x = te.thread_axis("blockIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    thread_x = te.thread_axis("threadIdx.x")
    thread_y = te.thread_axis("threadIdx.y")

    block_h = 32
    block_w = 32

    # Tile the output stage C and bind the tiles to blocks and threads.
    bx, xi = s[C].split(C.op.axis[0], factor=(block_h))
    by, yi = s[C].split(C.op.axis[1], factor=(block_w))
    s[C].bind(bx, block_x)
    s[C].bind(by, block_y)
    s[C].reorder(bx, by, xi, yi)
    s[C].bind(xi, thread_x)
    s[C].bind(yi, thread_y)
    # TE's single compute_at covers what TIR splits into compute_at and
    # reverse_compute_at: here the cache stage CC is attached inside C.
    s[CC].compute_at(s[C], yi)

What actually confuses me is why we need a rule that checks whether a block is an output block when applying compute_at.

// tvm/src/tir/schedule/primitive/compute_at.cc:557
// Check condition 4): `block` is not an output block
// Only compute_at runs this check (reverse_compute_at skips it). Presumably an
// output block's result is consumed outside this scope, so there is no
// in-scope consumer loop to move it under — TODO confirm against the source.
  if (is_compute_at) {
    CheckNotOutputBlock(self, block_sref, scope_root_sref);
  }

You’re right. compute_at in the TIR schedule is a little different from that in the TE schedule: TIR has two primitives, compute_at and reverse_compute_at, while TE merges the two into a single compute_at.

Thanks siyuan @Hzfengsy , this helps me a lot.

Hi Siyuan, I now have a few more questions about the TIR cache-related primitives.

The code I showed below can be accessed by this link: https://github.com/LeiWang1999/tvm_gpu_gemm/blob/6ab46c0596ffe10ec7e1940ec0b1bcf9971c6b8f/tensorir/2.thread_tiling.py

The first one is about cache_read primitive.

'''
read_buffer_index : 0->A 1->B
'''
block_b = sch.get_block("B")
# Each cache_read inserts a copy block and makes block_b read the copy. After
# the first call, read index 0 of block_b refers to A_shared, so the second
# call caches A_shared (not A) into local memory: A -> A_shared ->
# A_shared_local. The same holds for B at read index 1.
block_shared_A = sch.cache_read(block_b, 0, "shared")
block_local_A = sch.cache_read(block_b, 0, "local")
block_shared_B = sch.cache_read(block_b, 1, "shared")
block_local_B = sch.cache_read(block_b, 1, "local")
# block_b now writes C_local; a new block copies C_local into C.
block_cl = sch.cache_write(block_b, 0, "local")

At first, I tried cache_read(block_b, 0, "local"), but found that it only cached A. I guessed that the second parameter of cache_read, read_buffer_index, selects which buffer gets cached, but no documentation told me that. Maybe we need a demo or better documentation — or is there a better solution?

Anyway, what confuses me about this code is whether the data is cached from global memory directly into local memory, or from shared memory into local memory, given that the first parameter (the BlockRV) is block_b. I first tried the code below:

# NOTE(review): this variant does not build global -> shared -> local.
# block_shared_A reads the global buffer A, so caching its read index 0
# inserts a local copy of A *in front of* A_shared rather than after it.
# block_shared_B reads a single buffer (B), so read index 1 presumably does
# not exist for it — confirm.
block_b = sch.get_block("B")
block_shared_A = sch.cache_read(block_b, 0, "shared")
block_local_A = sch.cache_read(block_shared_A , 0, "local")
block_shared_B = sch.cache_read(block_b, 1, "shared")
block_local_B = sch.cache_read(block_shared_B , 1, "local")
block_cl = sch.cache_write(block_b, 0, "local")

This is also adapted from tensor expression:

    # TE: AA/BB are shared-scope copies of A/B and AL/BL are local-scope
    # copies of AA/BB, all with C as the reader.
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])

It makes more sense to me. Could you kindly explain why it is designed this way?

The second one is another question about compute_at

To move A_shared and A_local to the right places with compute_at, I found that we must apply compute_at to the local (register) blocks first and only then to the shared-memory blocks. If we reverse the order:

# NOTE: this order fails in TIR (see the error quoted below). When the shared
# copies are moved under `ko`, their consumers (the local copy blocks) are not
# under that loop yet. Moving the local copies first avoids this.
sch.compute_at(block_shared_A, ko)
sch.compute_at(block_shared_B, ko)
sch.compute_at(block_local_A, ki)
sch.compute_at(block_local_B, ki)

then we get an exception: The primitive requires all the consumer(s) of the given block to be present under the target loop. However, there are 1 consumer(s) not satisfying the constraint. List of the consumer(s):tir.Block#0

and the right code is:

# Working order: local copies first (under ki), then shared copies (under ko).
# Each producer's consumers are then already under the target loop.
sch.compute_at(block_local_A, ki)
sch.compute_at(block_local_B, ki)
sch.compute_at(block_shared_A, ko)
sch.compute_at(block_shared_B, ko)

I think this imposes some programming limitations, because in tensor expression we can write:

    # TE schedules lazily, so these attachments need no particular order.
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[AL].compute_at(s[CC], ki)
    s[BL].compute_at(s[CC], ki)

The last question is about get_loops

We have a tir script below:

for i_0 in T.thread_binding(128, thread="blockIdx.y"):
        for j_0 in T.thread_binding(128, thread="blockIdx.x"):
            for i_1_0 in T.thread_binding(16, thread="threadIdx.y"):
                for j_1_0 in T.thread_binding(16, thread="threadIdx.x"):
                    for k_0 in T.serial(1024):
                        for ax0, ax1 in T.grid(128, 16):
                            with T.block("A_shared"):
                                v0 = T.axis.spatial(16384, i_0 * 128 + ax0)
                                v1 = T.axis.spatial(16384, k_0 * 16 + ax1)
                                T.reads(A[v0, v1])
                                T.writes(A_shared[v0, v1])
                                A_shared[v0, v1] = A[v0, v1]
                        for ax0, ax1 in T.grid(16, 128):
                            with T.block("B_shared"):
                                v0 = T.axis.spatial(16384, k_0 * 16 + ax0)
                                v1 = T.axis.spatial(16384, j_0 * 128 + ax1)
                                T.reads(B[v0, v1])
                                T.writes(B_shared[v0, v1])
                                B_shared[v0, v1] = B[v0, v1]

I want to use multiple threads to fetch A_shared from global memory. To do this in tensor expression, we would write:

    # Split both axes of the shared copy AA into nparts equal to the thread
    # counts, then bind the outer parts to thread indices.
    aa_tx, aa_xi = s[AA].split(s[AA].op.axis[0], nparts=Block_Size_X)
    aa_ty, aa_yi = s[AA].split(s[AA].op.axis[1], nparts=Block_Size_Y)
    s[AA].reorder(aa_tx, aa_ty, aa_xi, aa_yi)
    s[AA].bind(aa_tx, thread_x)
    s[AA].bind(aa_ty, thread_y)

but in the TensorIR schedule, block_shared_A does not end up with just 2 loops. At the beginning of the code,

# Build global -> shared -> local copy chains for block_b's reads:
# read index 0 (A) and read index 1 (B).
block_shared_A = sch.cache_read(block_b, 0, "shared")
block_local_A = sch.cache_read(block_b, 0, "local")
block_shared_B = sch.cache_read(block_b, 1, "shared")
block_local_B = sch.cache_read(block_b, 1, "local")

block_shared_A has two loops. After we apply the compute_at primitive, however, it has 7 loops, while we only need the loops `for ax0, ax1 in T.grid(128, 16)`:

# Local copies first, then shared copies (the order matters in TIR).
sch.compute_at(block_local_A, ki)
sch.compute_at(block_local_B, ki)
sch.compute_at(block_shared_A, ko)
sch.compute_at(block_shared_B, ko)

# get_loops returns every loop around the block: i_0, j_0, i_1_0, j_1_0, k_0
# from the enclosing nest plus the copy's own ax0, ax1, so the copy loops are
# the last two.
aa_yi, aa_xi = sch.get_loops(block_shared_A)[-2:] # 7 loops in total; keep the copy's last two
# Split each copy loop so that its outer part can be bound to a thread index.
aa_ty, aa_yi = sch.split(aa_yi, factors=[Block_Size_Y, None])
aa_tx, aa_xi = sch.split(aa_xi, factors=[Block_Size_X, None])
sch.reorder(aa_ty, aa_tx, aa_yi, aa_xi)
sch.bind(aa_ty, "threadIdx.y")
sch.bind(aa_tx, "threadIdx.x")

I have to use an ugly workaround to extract the inner loops of block_shared_A. What confuses me is: why does block_shared_A have 7 loops after compute_at to ko? What do these loops represent? Is there an elegant way to get just the last two loops that we actually need?

thanks.

Thank you for bringing this up. Generally, the TIR schedule is not TE, which means we don’t aim to use the same API (we only make things similar). Additionally, the API design is based on core concepts: blocks in TIR versus stages in TE, and interactive mode in TIR versus lazy mode in TE.

To be specific:

cache read

In this case, they are exactly the same if you regard block_b as C.

compute_at

Interactive mode requires validation checks at each step, which means the order of actions matters. TE, however, uses lazy mode, so there the order of primitives can be changed to some extent.

get_loops

I agree that get_loops behaves differently from TE. But returning all loops around the block looks better to me than depending on the original loops. It is also a consequence of interactive mode.

1 Like

The CUDA kernel generated by the code I provided is slower than the Tensor Expression version, because the initialization of C is placed in the innermost loop:

  for (int k_0 = 0; k_0 < 1024; ++k_0) {
    __syncthreads();
    for (int ax0_1 = 0; ax0_1 < 8; ++ax0_1) {
      A_shared[(((((int)threadIdx.y) * 128) + (ax0_1 * 16)) + ((int)threadIdx.x))] = A[(((((((int)blockIdx.y) * 2097152) + (((int)threadIdx.y) * 131072)) + (ax0_1 * 16384)) + (k_0 * 16)) + ((int)threadIdx.x))];
    }
    for (int ax1_1 = 0; ax1_1 < 8; ++ax1_1) {
      B_shared[(((((int)threadIdx.y) * 128) + (((int)threadIdx.x) * 8)) + ax1_1)] = B[(((((k_0 * 262144) + (((int)threadIdx.y) * 16384)) + (((int)blockIdx.x) * 128)) + (((int)threadIdx.x) * 8)) + ax1_1)];
    }
    __syncthreads();
    for (int k_1 = 0; k_1 < 16; ++k_1) {
      for (int ax0 = 0; ax0 < 8; ++ax0) {
        A_shared_local[ax0] = A_shared[(((((int)threadIdx.y) * 128) + (ax0 * 16)) + k_1)];
      }
      for (int ax01 = 0; ax01 < 8; ++ax01) {
        B_shared_local[ax01] = B_shared[(((k_1 * 128) + (((int)threadIdx.x) * 8)) + ax01)];
      }
      for (int i_1_1 = 0; i_1_1 < 8; ++i_1_1) {
        for (int j_1_1 = 0; j_1_1 < 8; ++j_1_1) {
          if (((k_0 * 16) + k_1) == 0) {
            C_local[((i_1_1 * 8) + j_1_1)] = 0.000000e+00f;
          }
          C_local[((i_1_1 * 8) + j_1_1)] = (C_local[((i_1_1 * 8) + j_1_1)] + (A_shared_local[i_1_1] * B_shared_local[j_1_1]));
        }
      }
    }
  }

whereas the same schedule in TE initializes C_local at the beginning:

  for (int ii_c_init = 0; ii_c_init < 8; ++ii_c_init) {
    for (int jj_c_init = 0; jj_c_init < 8; ++jj_c_init) {
      C_local[((ii_c_init * 8) + jj_c_init)] = 0.000000e+00f;
    }
  }

This adds about 100 ms of overhead for the shape (16384, 16384, 16384). I think T.init() should be computed at the outermost loop, but I didn’t find any guidelines on how to get hold of the T.init() block.

        # Extents follow the remap order: i/vi over M, j/vj over N (spatial),
        # k/vk over K (reduction), matching C[M, N], A[M, K] and B[K, N].
        for i, j, k in T.grid(M, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # When lowered as-is, the init is guarded by "reduction index
                # == 0" inside the innermost loop; sch.decompose_reduction
                # hoists it into its own block.
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

Done: solved by sch.decompose_reduction(block_b, k).