Sure! Here is the code that defines the schedule. (ENV is our accelerator environment object, defined elsewhere; it provides inp_dtype and the memory scopes used below.)
import tvm
from tvm import te

dim_I = 128
dim_K = 256
dim_J = 64
factor = 16
inp_shape = (dim_I, dim_K)
wght_shape = (dim_K, dim_J)
out_shape = (dim_I, dim_J)
# calculate inp @ wght + bias = out
inp = te.placeholder(inp_shape, dtype="int8", name="a_in")
wght = te.placeholder(wght_shape, dtype="int8", name="b_in")
bias = te.placeholder(out_shape, dtype="int32", name="bias_in")
rk = te.reduce_axis((0, dim_K), name="k_o")
res = te.compute(
    out_shape,
    lambda i, j: te.sum(
        inp[i, rk].astype(ENV.inp_dtype) * wght[rk, j].astype(ENV.inp_dtype)
        + bias[i, j].astype(ENV.inp_dtype),
        axis=[rk],
    ),
    name="res",
    tag="dense",
)
sch = te.create_schedule(res.op)
dense_op = res.op
data, weight, bias_op = res.op.input_tensors
i, j = sch[res].op.axis
(k_axis,) = sch[res].op.reduce_axis
# Create Data Locality
sch[bias_op].set_scope(ENV.acc_scope)
cdata = sch.cache_read(inp, ENV.scr_scope, [dense_op])
cweight = sch.cache_read(wght, ENV.scr_wgt_scope, [dense_op])
res_local = sch.cache_write(res, ENV.acc_scope)  # source = sch.cache_write(dest)
(i_axis, j_axis) = sch[res_local].op.axis
(k_axis_int,) = sch[res_local].op.reduce_axis
i_axis_outer, i_axis_inner = sch[res_local].split(i_axis, factor=factor)
j_axis_outer, j_axis_inner = sch[res_local].split(j_axis, factor=factor)
k_axis_outer, k_axis_inner = sch[res_local].split(k_axis, factor=factor)
sch[res_local].reorder(i_axis_outer, j_axis_outer, k_axis_outer,
                       k_axis_inner, j_axis_inner, i_axis_inner)
sch[cdata].compute_at(sch[res_local], k_axis_outer)
sch[cweight].compute_at(sch[res_local], k_axis_outer)
sch[bias_op].compute_at(sch[res_local], k_axis_outer)
# The following line causes a crash, so it is commented out for the IR dump below:
# sch[res].compute_at(sch[res_local], i_axis_outer)
print(tvm.lower(sch, [inp, wght, bias, res], simple_mode=True))
And here is the resulting IR (generated with the crashing compute_at line commented out):
class Module:
    @T.prim_func
    def main(a_in: T.Buffer((128, 256), "int8"), b_in: T.Buffer((256, 64), "int8"), bias_in: T.Buffer((128, 64), "int32"), res: T.Buffer((128, 64), "int8")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        res_local_accumulator = T.allocate([8192], "int8", "local.accumulator")
        a_in_local_scratchpad = T.allocate([256], "int8", "local.scratchpad")
        b_in_local_scratchpad_weight = T.allocate([256], "int8", "local.scratchpad_weight")
        res_local_accumulator_1 = T.Buffer((8192,), "int8", data=res_local_accumulator, scope="local.accumulator", align=4)
        for i_c_outer, j_c_outer in T.grid(8, 4):
            for j_c_inner_init, i_c_inner_init in T.grid(16, 16):
                res_local_accumulator_1[i_c_outer * 1024 + i_c_inner_init * 64 + j_c_outer * 16 + j_c_inner_init] = T.int8(0)
            for k_o_outer in range(16):
                a_in_local_scratchpad_1 = T.Buffer((256,), "int8", data=a_in_local_scratchpad, scope="local.scratchpad", align=4)
                for ax0, ax1 in T.grid(16, 16):
                    a_in_1 = T.Buffer((32768,), "int8", data=a_in.data)
                    a_in_local_scratchpad_1[ax0 * 16 + ax1] = a_in_1[i_c_outer * 4096 + ax0 * 256 + k_o_outer * 16 + ax1]
                b_in_local_scratchpad_weight_1 = T.Buffer((256,), "int8", data=b_in_local_scratchpad_weight, scope="local.scratchpad_weight", align=4)
                for ax0, ax1 in T.grid(16, 16):
                    b_in_1 = T.Buffer((16384,), "int8", data=b_in.data)
                    b_in_local_scratchpad_weight_1[ax0 * 16 + ax1] = b_in_1[k_o_outer * 1024 + ax0 * 64 + j_c_outer * 16 + ax1]
                for k_o_inner, j_c_inner, i_c_inner in T.grid(16, 16, 16):
                    cse_var_1: T.int32 = i_c_outer * 1024 + i_c_inner * 64 + j_c_outer * 16 + j_c_inner
                    bias_in_1 = T.Buffer((8192,), "int32", data=bias_in.data)
                    res_local_accumulator_1[cse_var_1] = res_local_accumulator_1[cse_var_1] + (a_in_local_scratchpad_1[i_c_inner * 16 + k_o_inner] * b_in_local_scratchpad_weight_1[k_o_inner * 16 + j_c_inner] + T.Cast("int8", bias_in_1[cse_var_1]))
        for i, j in T.grid(128, 64):
            cse_var_2: T.int32 = i * 64 + j
            res_1 = T.Buffer((8192,), "int8", data=res.data)
            res_1[cse_var_2] = res_local_accumulator_1[cse_var_2]
I want to move the stage that defines res_1 (the copy from the local accumulator back to res) into the k_o_outer loop. How can I do that? The sch[res].compute_at(...) line commented out above crashes, so that direct approach does not seem to work.
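For reference, here is a minimal sketch of the opposite attachment direction, which as far as I understand is the more common pattern: tile the copy-out stage and attach the accumulator stage to it, rather than attaching the copy-out stage into the accumulator. The names i_out/j_out and the extra splits of res are hypothetical and only illustrate the idea; the sketch reuses sch, res, res_local and factor from the code above and would replace the crashing compute_at line:
# Hypothetical alternative, just for comparison: tile the copy-out stage res
# and attach the accumulator stage res_local to one of its outer axes.
i_out, j_out = sch[res].op.axis  # re-query the axes of res after cache_write
i_out_outer, i_out_inner = sch[res].split(i_out, factor=factor)
j_out_outer, j_out_inner = sch[res].split(j_out, factor=factor)
sch[res].reorder(i_out_outer, j_out_outer, i_out_inner, j_out_inner)
sch[res_local].compute_at(sch[res], j_out_outer)
This interleaves the accumulation and the copy-out per 16x16 output tile, but the copy still sits outside the k_o_outer loop, which is exactly the part I would like to change.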