How to tile buffers to fit into on-chip memory?

Hello,
I am trying to understand how to use TVM for conv2d on a custom accelerator, so I have created a toy example of a conv2d operation. So far I understand how to tile the computation, reorder loops, and use cache_read. This is the TIR I can generate at the moment:

@I.ir_module
class Module:
    @T.prim_func
    def main(input: T.Buffer((2, 112, 64, 16), "int8"), kernel: T.Buffer((3, 3, 16, 16), "int8"), bias: T.Buffer((16,), "int32"), output: T.Buffer((2, 64, 112, 16), "int8")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        res = T.allocate([229376], "int8", "global")
        input_local_inp_buffer = T.allocate([1600], "int8", "local.inp_buffer")
        kernel_local_wgt_buffer = T.allocate([2304], "int8", "local.wgt_buffer")
        bias_local_acc_buffer = T.allocate([16], "int32", "local.acc_buffer")
        for b_o_outer, h_o_outer, w_o_outer in T.grid(2, 8, 14):
            res_1 = T.Buffer((229376,), "int8", data=res)
            for h_o_inner_init, w_o_inner_init, c_o_inner_init in T.grid(8, 8, 16):
                res_1[b_o_outer * 114688 + h_o_outer * 14336 + h_o_inner_init * 1792 + w_o_outer * 128 + w_o_inner_init * 16 + c_o_inner_init] = T.int8(0)
            input_local_inp_buffer_1 = T.Buffer((1600,), "int8", data=input_local_inp_buffer, scope="local.inp_buffer")
            for ax1, ax2 in T.grid(10, 10):
                if T.likely(ax2 // 8 + w_o_outer < 8):
                    for ax3 in range(16):
                        cse_var_1: T.int32 = ax2 * 16
                        input_1 = T.Buffer((229376,), "int8", data=input.data)
                        input_local_inp_buffer_1[ax1 * 160 + cse_var_1 + ax3] = input_1[b_o_outer * 114688 + h_o_outer * 8192 + ax1 * 1024 + w_o_outer * 128 + cse_var_1 + ax3]
            kernel_local_wgt_buffer_1 = T.Buffer((2304,), "int8", data=kernel_local_wgt_buffer, scope="local.wgt_buffer")
            for ax0, ax1, ax2, ax3 in T.grid(3, 3, 16, 16):
                cse_var_2: T.int32 = ax0 * 768 + ax1 * 256 + ax2 * 16 + ax3
                kernel_1 = T.Buffer((2304,), "int8", data=kernel.data)
                kernel_local_wgt_buffer_1[cse_var_2] = kernel_1[cse_var_2]
            bias_local_acc_buffer_1 = T.Buffer((16,), "int32", data=bias_local_acc_buffer, scope="local.acc_buffer")
            for ax0 in range(16):
                bias_1 = T.Buffer((16,), "int32", data=bias.data)
                bias_local_acc_buffer_1[ax0] = bias_1[ax0]
            for h_o_inner, w_o_inner, c_o_inner, rkw_inner, rkh_inner, ric_inner in T.grid(8, 8, 16, 3, 3, 16):
                cse_var_4: T.int32 = w_o_inner * 16
                cse_var_3: T.int32 = b_o_outer * 114688 + h_o_outer * 14336 + h_o_inner * 1792 + w_o_outer * 128 + cse_var_4 + c_o_inner
                res_1[cse_var_3] = res_1[cse_var_3] + (input_local_inp_buffer_1[h_o_inner * 160 + rkh_inner * 160 + cse_var_4 + rkw_inner * 16 + ric_inner] * kernel_local_wgt_buffer_1[rkh_inner * 768 + rkw_inner * 256 + ric_inner * 16 + c_o_inner] + T.Cast("int8", bias_local_acc_buffer_1[c_o_inner]))

Input, weight, and bias are tiled and partitioned into small buffers that fit into the on-chip memory. My problem is the result of the computation: in the IR, res_1 is not affected by my tiling and still covers the full output (229376 bytes) rather than the 8x8x16 tile processed per outer iteration, so I get an error that the accumulator is too small to fit the results. How can I fix this?

For the other buffers, here is the schedule code so far:

res = te.compute(
    oshape,
    lambda b_o, h_o, w_o, c_o: te.sum(
        data[b_o, h_o * hstr + rkh, w_o * wstr + rkw, ric].astype(ENV.inp_dtype)
        * kernel[rkh, rkw, ric, c_o].astype(ENV.inp_dtype)
        + bias[c_o].astype(ENV.inp_dtype),
        axis=[rkh, rkw, ric],
    ),
    name="res",
    tag="conv2d",
)
sch = te.create_schedule(res.op)

outer_b, inner_b   = sch[res].split(res.op.axis[0], factor=factor_b)
outer_oh, inner_oh = sch[res].split(res.op.axis[1], factor=factor_hw)
outer_ow, inner_ow = sch[res].split(res.op.axis[2], factor=factor_hw)
outer_oc, inner_oc = sch[res].split(res.op.axis[3], factor=factor_oc)
outer_kh, inner_kh = sch[res].split(res.op.reduce_axis[0], factor=factor_khw)
outer_kw, inner_kw = sch[res].split(res.op.reduce_axis[1], factor=factor_khw)
outer_kc, inner_kc = sch[res].split(res.op.reduce_axis[2], factor=factor_kc)

sch[res].reorder(outer_b, outer_oh, outer_ow, outer_oc, outer_kw, outer_kh, outer_kc,
                 inner_b, inner_oh, inner_ow, inner_oc, inner_kw, inner_kh, inner_kc)

cdata = sch.cache_read(input, "local.inp_buffer", [res])
ckernel = sch.cache_read(kernel, "local.wgt_buffer", [res])
cbias = sch.cache_read(bias, "local.acc_buffer", [res])
sch[res].set_scope("local.acc_buffer")

sch[cdata].compute_at(sch[res], outer_kc)
sch[ckernel].compute_at(sch[res], outer_kc)
sch[cbias].compute_at(sch[res], outer_kc)
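
For reference, the placeholders and split factors that feed into this look roughly as follows (simplified: the shapes and factors are the toy values you can see in the TIR dump above, and ENV is reduced to a minimal stand-in for my accelerator description, so don't read too much into the exact numbers):

from tvm import te
from types import SimpleNamespace

# Minimal stand-in for my accelerator description (the real ENV has more fields).
ENV = SimpleNamespace(inp_dtype="int8")

# Shapes as they appear in the TIR dump above (toy values).
batch, in_ch, out_ch, ksize = 2, 16, 16, 3
hstr = wstr = 1
oshape = (batch, 64, 112, out_ch)

# In my script `data` and `input` refer to the same placeholder; both names are
# kept here so the snippets above read unchanged.
input = data = te.placeholder((batch, 112, 64, in_ch), dtype="int8", name="input")
kernel = te.placeholder((ksize, ksize, in_ch, out_ch), dtype="int8", name="kernel")
bias = te.placeholder((out_ch,), dtype="int32", name="bias")

rkh = te.reduce_axis((0, ksize), name="rkh")
rkw = te.reduce_axis((0, ksize), name="rkw")
ric = te.reduce_axis((0, in_ch), name="ric")

# Split factors, matching the inner loop extents in the TIR dump.
factor_b, factor_hw, factor_oc = 1, 8, 16
factor_khw, factor_kc = 3, 16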

I have tried cres = sch.cache_write(res, "accumulator"), but that results in an error: Check failed: iv->iter_type == kDataPar (2 vs. 0) : Can only relayout with in data parallel dimensions
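
To be concrete, the attempt looked roughly like this; the compute_at line is only what I intended to do next, it never runs because the cache_write call itself fails:

# This is the call that triggers the kDataPar check failure in my script.
cres = sch.cache_write(res, "accumulator")
# Intended follow-up (never reached): keep one output tile in the accumulator
# per outer iteration, analogous to the cache_read stages above.
sch[cres].compute_at(sch[res], outer_oc)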

Maybe some more information is helpful. So far I have tried cache_write, which leads to the aforementioned error. The other thing I tested is just sch[res].set_scope("local.acc_buffer"), but that doesn't change anything related to the buffer layouts. Any help here is greatly appreciated :slight_smile:

Ok, I just stumbled over something. If I change the way I handle the bias stage like this:

cdata = sch.cache_read(input, "local.inp_buffer", [res])
ckernel = sch.cache_read(kernel, "local.wgt_buffer", [res])
cbias = sch.cache_read(bias, "local.acc_buffer", [res])

sch[cdata].compute_at(sch[res], outer_kc)
sch[ckernel].compute_at(sch[res], outer_kc)
sch[cbias].set_scope("local.accumulator")

Then I no longer get the error Check failed: const_size * dtype.bits() <= info->max_num_bits (51380224 vs. 2097152): Allocation exceed bound of memory tag local.accumulator. My question now is: why? What does this do differently compared to the cache_read operation? And when should you use one over the other?
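
For context, my current mental model (which may well be wrong, hence the question) is roughly this:

# cache_read adds an explicit copy stage that reads bias into a new buffer with
# the given scope; compute_at (as in my first snippet) then places that copy
# inside the tiled loop nest, so only a small region has to live on-chip.
cbias = sch.cache_read(bias, "local.acc_buffer", [res])
sch[cbias].compute_at(sch[res], outer_kc)

# set_scope (as in the second snippet, where I dropped the compute_at) only
# re-tags the storage scope of the stage's existing output buffer; it adds no
# copy stage and by itself does not change the buffer's extent.
sch[cbias].set_scope("local.accumulator")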

I can't find anything in the documentation about what exactly set_scope does, so any help here would be greatly appreciated :slightly_smiling_face: