"Output must be attached at root" with cache_write

Hello, I am a bit confused about cache_write and how it affects the resulting schedule. I am experimenting with a 3-loop matmul operation, where each loop should be split into two parts, with the computation being tensorized. Without the cache_write operation, this works fine. However, when I add a cache_write like this:

res_local = sch.cache_write(res, "local.acc")

with res being the result of the TE definition of the computation. I now want to move the copy-out stage (the write from res_local back into res) into the compute loop. But when I try to do this, I get `Error: Output must be attached at root`. Why is this restriction in place, and can I somehow get around it?

Could you post sample code to reproduce the issue? It's a bit hard to understand what the problem is without looking at the IR.

Sure! This is the code to define the schedule:

dim_I = 128
dim_K = 256
dim_J = 64

factor = 16

inp_shape =  (dim_I, dim_K)
wght_shape = (dim_K, dim_J)
out_shape =  (dim_I, dim_J)

# calculate inp @ wght + bias = out
inp = te.placeholder(inp_shape, dtype="int8", name="a_in")
wght = te.placeholder(wght_shape, dtype="int8", name="b_in")
bias = te.placeholder(out_shape, dtype="int32", name="bias_in")

rk = te.reduce_axis((0, dim_K), name="k_o")
res = te.compute(
    out_shape,
    lambda i, j: te.sum(
        inp[i, rk].astype(ENV.inp_dtype) * wght[rk, j].astype(ENV.inp_dtype)
        + bias[i, j].astype(ENV.inp_dtype),
        axis=rk,
    ),
    name="res",
)

sch = te.create_schedule(res.op)
dense_op = res.op

data, weight, bias_op = res.op.input_tensors

i, j = sch[res].op.axis
(k_axis,) = sch[res].op.reduce_axis

# Create Data Locality
cdata = sch.cache_read(inp, ENV.scr_scope, [dense_op])
cweight = sch.cache_read(wght, ENV.scr_wgt_scope, [dense_op])
res_local = sch.cache_write(res, ENV.acc_scope) #source = sch.cache_write(dest)

(i_axis, j_axis) = sch[res_local].op.axis
(k_axis_int,) = sch[res_local].op.reduce_axis

i_axis_outer, i_axis_inner = sch[res_local].split(i_axis, factor=factor)
j_axis_outer, j_axis_inner = sch[res_local].split(j_axis, factor=factor)
k_axis_outer, k_axis_inner = sch[res_local].split(k_axis, factor=factor)

sch[res_local].reorder(i_axis_outer, j_axis_outer, k_axis_outer,
    k_axis_inner, j_axis_inner, i_axis_inner)

sch[cdata].compute_at(sch[res_local], k_axis_outer)
sch[cweight].compute_at(sch[res_local], k_axis_outer)
sch[bias_op].compute_at(sch[res_local], k_axis_outer)

sch[res].compute_at(sch[res_local], i_axis_outer)  # This line causes the crash

print(tvm.lower(sch, [inp, wght, bias, res], simple_mode=True))

And here is the resulting IR (with the offending line commented out):

class Module:
    def main(a_in: T.Buffer((128, 256), "int8"), b_in: T.Buffer((256, 64), "int8"), bias_in: T.Buffer((128, 64), "int32"), res: T.Buffer((128, 64), "int8")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        res_local_accumulator = T.allocate([8192], "int8", "local.accumulator")
        a_in_local_scratchpad = T.allocate([256], "int8", "local.scratchpad")
        b_in_local_scratchpad_weight = T.allocate([256], "int8", "local.scratchpad_weight")
        res_local_accumulator_1 = T.Buffer((8192,), "int8", data=res_local_accumulator, scope="local.accumulator", align=4)
        for i_c_outer, j_c_outer in T.grid(8, 4):
            for j_c_inner_init, i_c_inner_init in T.grid(16, 16):
                res_local_accumulator_1[i_c_outer * 1024 + i_c_inner_init * 64 + j_c_outer * 16 + j_c_inner_init] = T.int8(0)
            for k_o_outer in range(16):
                a_in_local_scratchpad_1 = T.Buffer((256,), "int8", data=a_in_local_scratchpad, scope="local.scratchpad", align=4)
                for ax0, ax1 in T.grid(16, 16):
                    a_in_1 = T.Buffer((32768,), "int8", data=a_in.data)
                    a_in_local_scratchpad_1[ax0 * 16 + ax1] = a_in_1[i_c_outer * 4096 + ax0 * 256 + k_o_outer * 16 + ax1]
                b_in_local_scratchpad_weight_1 = T.Buffer((256,), "int8", data=b_in_local_scratchpad_weight, scope="local.scratchpad_weight", align=4)
                for ax0, ax1 in T.grid(16, 16):
                    b_in_1 = T.Buffer((16384,), "int8", data=b_in.data)
                    b_in_local_scratchpad_weight_1[ax0 * 16 + ax1] = b_in_1[k_o_outer * 1024 + ax0 * 64 + j_c_outer * 16 + ax1]
                for k_o_inner, j_c_inner, i_c_inner in T.grid(16, 16, 16):
                    cse_var_1: T.int32 = i_c_outer * 1024 + i_c_inner * 64 + j_c_outer * 16 + j_c_inner
                    bias_in_1 = T.Buffer((8192,), "int32", data=bias_in.data)
                    res_local_accumulator_1[cse_var_1] = res_local_accumulator_1[cse_var_1] + (a_in_local_scratchpad_1[i_c_inner * 16 + k_o_inner] * b_in_local_scratchpad_weight_1[k_o_inner * 16 + j_c_inner] + T.Cast("int8", bias_in_1[cse_var_1]))
        for i, j in T.grid(128, 64):
            cse_var_2: T.int32 = i * 64 + j
            res_1 = T.Buffer((8192,), "int8", data=res.data)
            res_1[cse_var_2] = res_local_accumulator_1[cse_var_2]

I want to move the stage defining res_1 into the k_o_outer loop.

I did some experiments and found another post on the forum describing the same problem. The solution seems to be to migrate from TE-based scheduling to TensorIR (TIR) scheduling, which doesn't have the same restriction. (As far as I understand, the legacy TE pipeline requires output tensors, which are function arguments, to be produced by a root-level stage, so an output stage cannot be nested under another stage's loops.)


Sorry I could not reply earlier, as I was on vacation for a while. Yes, I would probably have replied the same: partly because I'm more familiar with TIR schedules than TE, and partly because cache read/write in TIR seems more intuitive to me than in TE, which I always found a bit confusing.