Could a TIR transform support cache_read to a different dtype, like float16 to int32?

In some cases we may need to read data as float16 and coalesce it into a wider type like int32 for better memory-access performance.

But the current cache_read API seems to use the original block's dtype directly, so how can I tell TIR to use a different dtype in the transform?

Take a transpose as the example here. The original TIR code is:

@I.ir_module
class ModC:
    @T.prim_func
    def transpose(rxplaceholder: T.Buffer((T.int64(1), T.int64(64), T.int64(1088), T.int64(1920)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(1088), T.int64(1920), T.int64(64)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1088), T.int64(1920), T.int64(64)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2]

    @R.function
    def main(x: R.Tensor((1, 64, 1088, 1920), dtype="float16")) -> R.Tensor((1, 1088, 1920, 64), dtype="float16"):
        R.func_attr({"num_input": 1})
        cls = ModC
        with R.dataflow():
            gv = R.call_tir(cls.transpose, (x,), out_sinfo=R.Tensor((1, 1088, 1920, 64), dtype="float16"))
            R.output(gv)
        return gv
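
For reference, this is roughly how I call cache_read on that module today (a minimal sketch against the transpose PrimFunc only); as far as I can tell the primitive takes just a read buffer index and a storage scope, and the cache buffer inherits the original float16 dtype:

from tvm import tir

# minimal sketch: schedule only the transpose PrimFunc of ModC above
sch = tir.Schedule(ModC["transpose"])
block = sch.get_block("T_transpose")
# cache_read(block, read_buffer_index, storage_scope) -- there is no
# argument here to ask for an int32 cache buffer
AA = sch.cache_read(block, 0, "shared")
print(sch.mod.script())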

How could I transform this transpose into code like the one below?

@I.ir_module
class ModB:
    @T.prim_func(private=True)
    def transpose(A: T.Buffer((T.int64(1), T.int64(64), T.int64(1088), T.int64(1920)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(1088), T.int64(1920), T.int64(64)), "float16")):
        T.func_attr({"op_pattern": 2, "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for bid in T.thread_binding(T.int64(32640), thread="blockIdx.x"):
            for tix in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for tiy in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                    with T.block("T_transpose"):
                        tmp = T.alloc_buffer((2113,), "int32", scope="shared")
                        base = 0x80 // 2 * bid + 128 * 32640 * tiy
                        A_1 = T.Buffer((T.int64(1920*1088*64),), "float16", data=A.data)
                        B_1 = T.Buffer((T.int64(1920*1088*64//2),), "int32", data=T_transpose.data)


                        buf0 = T.decl_buffer([4], "float16", scope="local")
                        buf1 = T.Buffer((T.int64(2),), "int32", data=buf0.data)
                        buf0[0] = A_1[base + tix + 32 * 0]
                        buf0[2] = A_1[base + tix + 32 * 1]
                        buf0[1] = A_1[base + tix + 32 * 2 + (32640 - 1) * 64]
                        buf0[3] = A_1[base + tix + 32 * 3 + (32640 - 1) * 64]

                        tmp[tiy + 33 * tix + 0x1080//4 * 0] = buf1[0]
                        tmp[tiy + 33 * tix + 0x1080//4 * 1] = buf1[1]

                        buf1[0] = tmp[tix + 33 * tiy + 0x1080//4 * 0]
                        buf1[1] = tmp[tix + 33 * tiy + 0x1080//4 * 1]

                        B_1[0x2000//4*bid+tix+tiy*32 + 0x1000//4 * 0] = buf1[0]
                        B_1[0x2000//4*bid+tix+tiy*32 + 0x1000//4 * 1] = buf1[1]


    @R.function
    def main(x: R.Tensor((1, 64, 1088, 1920), dtype="float16")) -> R.Tensor((1, 1088, 1920, 64), dtype="float16"):
        R.func_attr({"num_input": 1})
        cls = ModB
        with R.dataflow():
            gv = R.call_tir(cls.transpose, (x,), out_sinfo=R.Tensor((1, 1088, 1920, 64), dtype="float16"))
            R.output(gv)
        return gv

I know there is an unsafe_set_dtype function. Maybe that does what you need?

From the test case in https://github.com/apache/tvm/blob/main/tests/python/tir-schedule/test_tir_schedule_set_dtype.py, it seems to me that unsafe_set_dtype is used to force-cast the dtype.
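
Based on that test, my understanding is that it is used roughly like this (the buffer index and dtype below are only illustrative, and I am assuming the index refers to the block's write buffer):

from tvm import tir

sch = tir.Schedule(ModC["transpose"])
block = sch.get_block("T_transpose")
# force the dtype of the block's write buffer to int32; casts are
# inserted around the stores, but the number of elements stays the same,
# so two float16 values are not merged into one 32-bit word
sch.unsafe_set_dtype(block, buffer_index=0, dtype="int32")
print(sch.mod.script())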

But what I am looking for is more like coalescing multiple small storage elements into one larger element. So my example could be seen as asking how to merge two 16-bit reads/writes into one 32-bit read/write.
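
In other words, the only way I currently know to express this is the manual buffer aliasing in ModB above, where a second T.Buffer with a different dtype is declared over the same data pointer. A stripped-down sketch of just that trick (the function name and shapes are made up for illustration):

from tvm.script import tir as T

@T.prim_func
def copy_fp16_as_int32(A: T.Buffer((128,), "float16"), B: T.Buffer((128,), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    for i in range(64):
        with T.block("copy"):
            vi = T.axis.spatial(64, i)
            # alias the same storage as half as many int32 elements,
            # like A_1/B_1 in ModB above
            A_i32 = T.Buffer((64,), "int32", data=A.data)
            B_i32 = T.Buffer((64,), "int32", data=B.data)
            # one 32-bit access moves two adjacent float16 values
            B_i32[vi] = A_i32[vi]

What I would like to know is whether a schedule primitive (cache_read or something else) can produce this kind of coalesced access automatically.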