TIR Schedule: Capture iteration variable in intrin

Hi, is there a way to add a statement to an intrin I’ve defined that depends on an iteration variable of the surrounding loop? After tensorization, I’d like to end up with something like this:

for(i = 0; i < 10; i++){
    if(i == 0){
        call_function_a(...);
    }else{
        call_function_b(...);
    }
}

In this case, the iter_var i would be the result of split and reorder operations applied to the schedule.
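To make the intended control flow concrete, here is a minimal plain-Python sketch (not TVM code) of a loop that has been split into an outer and an inner part, where the inner part stands in for the region the intrin would replace. The names run_split_loop, call_function_a and call_function_b, as well as the extent of 40 and the factor of 4, are assumptions chosen only so that the outer loop has 10 iterations like the pseudocode above.

```python
def call_function_a(tile):
    # Hypothetical stand-in for the call emitted on the first outer iteration.
    return ("a", tile)


def call_function_b(tile):
    # Hypothetical stand-in for the call emitted on all later outer iterations.
    return ("b", tile)


def run_split_loop(extent=40, factor=4):
    """Model a loop of `extent` iterations split into outer (i) and inner loops."""
    assert extent % factor == 0, "sketch assumes the factor divides the extent"
    calls = []
    for i in range(extent // factor):  # outer loop that remains after the split
        # The `factor` inner iterations are what the intrin would cover.
        tile = list(range(i * factor, (i + 1) * factor))
        if i == 0:
            calls.append(call_function_a(tile))
        else:
            calls.append(call_function_b(tile))
    return calls
```

The point of the sketch is that the branch depends only on the outer variable i, while the work inside each branch covers the inner iterations.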

I’ve tried simply moving the loop over the iteration variable into the intrin description, but tensorization then fails with the following error:

The pattern attempting to be matched:
    for i_1 in range(16):
        with T.block("res_o_update"):
            v_i_o_i = T.axis.spatial(16, i_1)
            res_local_acc = T.Buffer((512, 256), "int32", scope="local.acc")
            v_i_o_o = T.int32()
            v_j_o = T.int32()
            a_in_local_spad = T.Buffer((512, 768), "int8", scope="local.spad")
            v_k_o = T.int32()
            b_in_local_spad_w = T.Buffer((768, 256), "int8", scope="local.spad_w")
            T.reads(res_local_acc[v_i_o_o * 64 + v_i_o_i * 4:v_i_o_o * 64 + v_i_o_i * 4 + 4, v_j_o * 4:v_j_o * 4 + 4], a_in_local_spad[v_i_o_o * 64 + v_i_o_i * 4:v_i_o_o * 64 + v_i_o_i * 4 + 4, v_k_o * 4:v_k_o * 4 + 4], b_in_local_spad_w[v_k_o * 4:v_k_o * 4 + 4, v_j_o * 4:v_j_o * 4 + 4])
            T.writes(res_local_acc[v_i_o_o * 64 + v_i_o_i * 4:v_i_o_o * 64 + v_i_o_i * 4 + 4, v_j_o * 4:v_j_o * 4 + 4])
            for k_2, j_2, i_2 in T.grid(4, 4, 4):
                with T.block("res"):
                    v_i_i, v_j_i, v_k_i = T.axis.remap("SSR", [i_2, j_2, k_2])
                    T.reads(res_local_acc[v_i_o_o * 64 + v_i_o_i * 4 + v_i_i, v_j_o * 4 + v_j_i], a_in_local_spad[v_i_o_o * 64 + v_i_o_i * 4 + v_i_i, v_k_o * 4 + v_k_i], b_in_local_spad_w[v_k_o * 4 + v_k_i, v_j_o * 4 + v_j_i])
                    T.writes(res_local_acc[v_i_o_o * 64 + v_i_o_i * 4 + v_i_i, v_j_o * 4 + v_j_i])
                    res_local_acc[v_i_o_o * 64 + v_i_o_i * 4 + v_i_i, v_j_o * 4 + v_j_i] = res_local_acc[v_i_o_o * 64 + v_i_o_i * 4 + v_i_i, v_j_o * 4 + v_j_i] + T.Cast("int32", a_in_local_spad[v_i_o_o * 64 + v_i_o_i * 4 + v_i_i, v_k_o * 4 + v_k_i]) * T.Cast("int32", b_in_local_spad_w[v_k_o * 4 + v_k_i, v_j_o * 4 + v_j_i])
    Does not match the tensorize description:
    for io, k, j, i in T.grid(16, 4, 4, 4):
        with T.block(""):
            vio, vii, vjj, vkk = T.axis.remap("SSSR", [io, i, j, k])
            C = T.Buffer((64, 4), "int32", scope="local.acc", offset_factor=1)
            A = T.Buffer((64, 4), "int8", scope="local.spad", offset_factor=1)
            B = T.Buffer((4, 4), "int8", scope="local.spad_w", offset_factor=1)
            T.reads(C[vii + vio, vjj + vio], A[vii + vio, vkk], B[vkk, vjj + vio])
            T.writes(C[vii + vio, vjj + vio])
            C[vii + vio, vjj + vio] = C[vii + vio, vjj + vio] + T.Cast("int32", A[vii + vio, vkk]) * T.Cast("int32", B[vkk, vjj + vio])

And the intrin is described like this:

    @T.prim_func
    def my_matmul_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (dim_io, dim_k), "int8", offset_factor=1, scope="local.spad")
        B = T.match_buffer(b, (dim_ko, dim_j), "int8", offset_factor=1, scope="local.spad_w")
        C = T.match_buffer(c, (dim_io, dim_j), "int32", offset_factor=1,  scope="local.acc")

        with T.block("root"):
            T.reads(C[0:dim_io, 0:dim_j],
                    A[0:dim_io, 0:dim_k],
                    B[0:dim_ko, 0:dim_j])
            T.writes(C[0:dim_io, 0:dim_j])

            for io in T.serial(0, io_extent):
                for k, j, i in T.grid(dim_k, dim_j, dim_i):
                    with T.block(""):
                        vio, vii, vjj, vkk = T.axis.remap("SSSR", [io, i, j, k])
                        C[vii+vio*dim_i, vjj] = C[vii+vio*dim_i, vjj] + T.cast(A[vii+vio*dim_i, vkk], "int32") * T.cast(B[vkk, vjj], "int32")

What am I doing wrong here in the description of the intrin? I noticed that some variables like v_i_o_o are automatically inserted, but I think they are always zero.

I think the main issue is that I don’t quite get how I need to describe the buffer matching part.

@Hzfengsy Do you maybe have an idea for a better way to solve this? I just want to access the loop variable, and copying the loop into the intrin doesn’t feel like the right way to solve this.