In some cases we may need to read data in a narrow type such as float16, but coalesce the accesses into a wider type such as int32 for better performance. However, the current cache_read API seems to use the original block's dtype directly, so how can I tell TIR to try a different dtype for the transform?
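To make the intent concrete, here is a minimal standalone sketch (a hypothetical kernel, not from my model) of the reinterpretation I have in mind: an int32 buffer aliasing the same storage as a float16 buffer via data=, so each access moves two float16 values at once.

from tvm.script import tir as T

@T.prim_func
def copy_as_int32(A: T.Buffer((T.int64(128),), "float16"), B: T.Buffer((T.int64(128),), "float16")):
    for i in T.serial(T.int64(64)):
        with T.block("copy"):
            vi = T.axis.spatial(T.int64(64), i)
            # Aliased int32 views over the same storage: one int32
            # element packs two adjacent float16 elements.
            A_i32 = T.Buffer((T.int64(64),), "int32", data=A.data)
            B_i32 = T.Buffer((T.int64(64),), "int32", data=B.data)
            B_i32[vi] = A_i32[vi]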
Take a transpose as the example here. The original TIR code is:
@I.ir_module
class ModC:
    @T.prim_func
    def transpose(rxplaceholder: T.Buffer((T.int64(1), T.int64(64), T.int64(1088), T.int64(1920)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(1088), T.int64(1920), T.int64(64)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1088), T.int64(1920), T.int64(64)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2]

    @R.function
    def main(x: R.Tensor((1, 64, 1088, 1920), dtype="float16")) -> R.Tensor((1, 1088, 1920, 64), dtype="float16"):
        R.func_attr({"num_input": 1})
        cls = ModC
        with R.dataflow():
            gv = R.call_tir(cls.transpose, (x,), out_sinfo=R.Tensor((1, 1088, 1920, 64), dtype="float16"))
            R.output(gv)
        return gv
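For reference, this is roughly what I tried (a sketch, assuming ModC above is in scope): cache_read stages the input into shared memory, but the cache buffer inherits float16 from the source, and cache_read takes no dtype argument, so I see no way to request int32 here.

import tvm
from tvm import tir

sch = tir.Schedule(ModC)
block = sch.get_block("T_transpose", func_name="transpose")
# The new cache block copies rxplaceholder's dtype (float16);
# there is no parameter to reinterpret it as int32.
cache = sch.cache_read(block, 0, "shared")
print(sch.mod.script())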
How could I transform it into code like the one below, where the buffers are reinterpreted as int32 so that each access moves two float16 values?
@I.ir_module
class ModB:
    @T.prim_func(private=True)
    def transpose(A: T.Buffer((T.int64(1), T.int64(64), T.int64(1088), T.int64(1920)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(1088), T.int64(1920), T.int64(64)), "float16")):
        T.func_attr({"op_pattern": 2, "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for bid in T.thread_binding(T.int64(32640), thread="blockIdx.x"):
            for tix in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for tiy in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                    with T.block("T_transpose"):
                        tmp = T.alloc_buffer((2113,), "int32", scope="shared")
                        base = 0x80 // 2 * bid + 128 * 32640 * tiy
                        # Flat views over the original storage: A_1 reads the input as
                        # float16, while B_1 reinterprets the output as int32 (each
                        # element packs two float16 values).
                        A_1 = T.Buffer((T.int64(1920 * 1088 * 64),), "float16", data=A.data)
                        B_1 = T.Buffer((T.int64(1920 * 1088 * 64 // 2),), "int32", data=T_transpose.data)
                        # buf1 aliases the 4-element float16 local buffer as 2 int32
                        # values, so shared and global memory move 32 bits at a time.
                        buf0 = T.decl_buffer([4], "float16", scope="local")
                        buf1 = T.Buffer((T.int64(2),), "int32", data=buf0.data)
                        buf0[0] = A_1[base + tix + 32 * 0]
                        buf0[2] = A_1[base + tix + 32 * 1]
                        buf0[1] = A_1[base + tix + 32 * 2 + (32640 - 1) * 64]
                        buf0[3] = A_1[base + tix + 32 * 3 + (32640 - 1) * 64]
                        tmp[tiy + 33 * tix + 0x1080 // 4 * 0] = buf1[0]
                        tmp[tiy + 33 * tix + 0x1080 // 4 * 1] = buf1[1]
                        buf1[0] = tmp[tix + 33 * tiy + 0x1080 // 4 * 0]
                        buf1[1] = tmp[tix + 33 * tiy + 0x1080 // 4 * 1]
                        B_1[0x2000 // 4 * bid + tix + tiy * 32 + 0x1000 // 4 * 0] = buf1[0]
                        B_1[0x2000 // 4 * bid + tix + tiy * 32 + 0x1000 // 4 * 1] = buf1[1]

    @R.function
    def main(x: R.Tensor((1, 64, 1088, 1920), dtype="float16")) -> R.Tensor((1, 1088, 1920, 64), dtype="float16"):
        R.func_attr({"num_input": 1})
        cls = ModB
        with R.dataflow():
            gv = R.call_tir(cls.transpose, (x,), out_sinfo=R.Tensor((1, 1088, 1920, 64), dtype="float16"))
            R.output(gv)
        return gv
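In case it helps the discussion, this is how I check that the hand-written version stays correct (a hedged sketch, assuming a CUDA device is available; exact build APIs may differ across TVM versions, and with_attr only re-attaches a global symbol because the PrimFunc is private):

import numpy as np
import tvm

dev = tvm.cuda()
f = tvm.build(ModB["transpose"].with_attr("global_symbol", "transpose"), target="cuda")
x = np.random.uniform(size=(1, 64, 1088, 1920)).astype("float16")
a = tvm.nd.array(x, dev)
b = tvm.nd.empty((1, 1088, 1920, 64), "float16", dev)
f(a, b)
# Transposition is exact, so the int32 reinterpretation must be bit-preserving.
np.testing.assert_allclose(b.numpy(), x.transpose(0, 2, 3, 1))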