A failing example of using compute_at with TVMScript

I am a TensorIR/TVMScript beginner. I ran an experiment that triggered an error; how should I deal with it? Thanks.

# Two-stage TensorIR function over 128x128 float32 buffers:
#   block "B": for each row vi, call the extern `test_cust_mul` with a pointer to
#              the start of A and a pointer to row vi of the intermediate B;
#   block "C": C[vi, vj] = B[vi, vj] + 1.0.
# Block "B" touches B only opaquely (through B.data), so `compute_at` of block
# "B" is not supported; `reverse_compute_at` of block "C" under B's `i` loop is.
@T.prim_func
def compute_at_call_extern(a: T.handle, c: T.handle) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i in range(128):
        with T.block("B"):
            vi = T.axis.spatial(128, i)
            T.reads([A[0:128, 0:128]])
            T.writes([B[vi, 0:128]])
            # tvm_access_ptr(ptype, data, offset, extent, rw_mask), rw_mask: 1 = read, 2 = write.
            # Extents and masks mirror the T.reads / T.writes regions declared above:
            # all of A is read, and one 128-element row of B is written.
            T.evaluate(T.call_extern("test_cust_mul",
                T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, 128 * 128, 1, dtype="handle"),
                T.tvm_access_ptr(T.type_annotation(dtype="float32"), B.data, vi * 128, 128, 2, dtype="handle"),
                dtype="handle"))

    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
        
def test_compute_at2():
    # Attempt to move the extern producer block "B" under the outer loop of the
    # consumer block "C". This is the call that raises
    # "ValueError: BufferRegion pattern match failed", because block "B" reaches
    # buffer B opaquely through B.data.
    schedule = tir.Schedule(compute_at_call_extern, debug_mask="all")
    print(schedule.mod.script())
    producer = schedule.get_block("B")
    consumer = schedule.get_block("C")
    outer_loop, _inner_loop = schedule.get_loops(consumer)
    schedule.compute_at(producer, outer_loop, preserve_unit_loops=True)
    print(schedule.mod.script())

Check failed: (false) is false: ValueError: BufferRegion pattern match failed

This is what I was hoping for:

for(i=0; i<128; i++) {
    test_cust_mul(A, B+i*128);
    for(j=0; j<128; j++) {
       C[i][j] = B[i][j] + 1.0;
    }
}

Thank you for your attention and reply!

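We don’t support opaque access to buffers here: block “B” hands `B.data` to an extern function via `tvm_access_ptr`, which makes its access to B opaque to the scheduler.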

I see. It seems your intention is to move block “C” under the loop that encloses block “B”. Is that correct? If so, you can use reverse_compute_at in this particular case.

Try this:

def test_compute_at2():
    # Instead of pulling producer "B" down, push consumer "C" up: place block "C"
    # under the single `i` loop that encloses block "B", so each extern call is
    # immediately followed by the row of "C" that consumes it.
    sch = tir.Schedule(compute_at_call_extern, debug_mask="all")
    print(sch.mod.script())
    block_b = sch.get_block("B")
    block_c = sch.get_block("C")
    loop_i, = sch.get_loops(block_b)
    sch.reverse_compute_at(block_c, loop_i, preserve_unit_loops=True)
    print(sch.mod.script())

@junrushao1994

Thank you for your reply; this is exactly what I wanted.

Following your suggestion, I ran another experiment in which I replaced the constant 128 with the symbolic variables OH and OW.

# Symbolic-shape variant: same two stages as the 128x128 version, over
# (OH, OW) float32 buffers.
#   block "B": for each row vi, call the extern `test_cust_mul` with a pointer to
#              the start of A and a pointer to row vi of the intermediate B;
#   block "C": C[vi, vj] = B[vi, vj] + 1.0.
@T.prim_func
def compute_at_call_extern(a: T.handle, c: T.handle) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    OH = T.var("int32")
    OW = T.var("int32")
    A = T.match_buffer(a, (OH, OW), "float32")
    B = T.alloc_buffer((OH, OW), "float32")
    C = T.match_buffer(c, (OH, OW), "float32")
    for i in range(OH):
        with T.block("B"):
            vi = T.axis.spatial(OH, i)
            T.reads([A[0:OH, 0:OW]])
            T.writes([B[vi, 0:OW]])
            # tvm_access_ptr(ptype, data, offset, extent, rw_mask), rw_mask: 1 = read, 2 = write.
            # Row vi of a row-major (OH, OW) buffer starts at vi * OW (the old
            # hard-coded vi * 128 was only correct when OW == 128). Extents and
            # masks mirror the T.reads / T.writes regions declared above.
            T.evaluate(T.call_extern("test_cust_mul",
                T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, OH * OW, 1, dtype="handle"),
                T.tvm_access_ptr(T.type_annotation(dtype="float32"), B.data, vi * OW, OW, 2, dtype="handle"),
                dtype="handle"))

    for i, j in T.grid(OH, OW):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0

def test_compute_at2():
    # Symbolic-shape version of the reverse_compute_at schedule: nest block "C"
    # under the `i` loop that encloses block "B". In the printed result the
    # consumer's outer loop has extent T.min(1, OH - i), not 1, because the
    # analyzer cannot prove that bound for a symbolic OH.
    schedule = tir.Schedule(compute_at_call_extern, debug_mask="all")
    print(schedule.mod.script())
    (row_loop,) = schedule.get_loops(schedule.get_block("B"))
    schedule.reverse_compute_at(
        schedule.get_block("C"),
        row_loop,
        preserve_unit_loops=True,
    )
    print(schedule.mod.script())

The printed result is:

# Printed module after `reverse_compute_at` of block "C" under the `i` loop of block "B".
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(a: T.handle, c: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        OH = T.var("int32")
        OW = T.var("int32")
        A = T.match_buffer(a, [OH, OW], dtype="float32")
        C = T.match_buffer(c, [OH, OW], dtype="float32")
        # body
        # with T.block("root")
        B = T.alloc_buffer([OH, OW], dtype="float32")
        for i in T.serial(0, OH):
            with T.block("B"):
                vi = T.axis.spatial(OH, i)
                T.reads([A[0 : OH, 0 : OW]])
                T.writes([B[vi, 0 : OW]])
                # NOTE(review): the row offset is still the hard-coded vi * 128 from the
                # input function; for an (OH, OW) buffer it should be vi * OW.
                T.evaluate(T.call_extern("test_cust_mul", T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, 0, 1, dtype="handle"), T.tvm_access_ptr(T.type_annotation(dtype="float32"), B.data, vi * 128, 0, 1, dtype="handle"), dtype="handle"))
            # Block "C" now sits under B's `i` loop. Its outer extent is
            # T.min(1, OH - i) instead of 1 because the analyzer could not prove
            # min(1, OH - i) == 1 for a symbolic OH.
            for ax0, ax1 in T.grid(T.min(1, OH - i), OW):  
                with T.block("C"):
                    vi = T.axis.spatial(OH, i + ax0)
                    vj = T.axis.spatial(OW, ax1)
                    T.reads([B[vi, vj]])
                    T.writes([C[vi, vj]])
                    C[vi, vj] = B[vi, vj] + T.float32(1)

I have a new problem: the generated consumer loop is `for ax0, ax1 in T.grid(T.min(1, OH - i), OW):`

rather than `for j in T.serial(0, OW):`

How can I deal with this problem?

Thank you very much for your reply!

The analyzer doesn’t seem to be able to prove the following equality:

min(1, OH - i) = 1

because the range of OH is not known.

But it seems we shouldn't need to prove min(1, OH - i) = 1 at all; I would expect the result to be simply:

for(i=0; i<OH; i++) {
    test_cust_mul(A, B+i*OW);
    for(j=0; j<OW; j++) {
       C[i][j] = B[i][j] + 1.0;
    }
}

Let me explain how compute-at / reverse-compute-at works.

Basically, we do integer set analysis to determine the loop domain of the block being moved (in our example, block C). In our particular case, the inferred domain is:

  [0 : T.min(1, OH - i), 0 : OW]
= [0 : 1, 0 : OW]    (given `i < OH`)

which is a 2-dimensional rectangle whose first dimension has extent 1 and whose second dimension has extent OW.

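The flag preserve_unit_loops controls whether the system generates the first loop; that loop could be removed because it has unit extent. In the symbolic case, however, the analyzer cannot simplify T.min(1, OH - i) to 1, so the loop is not recognized as a unit loop and remains.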

A more fundamental issue for these symbolic-shape cases: do we need a mechanism to embed the domains of variables in the IR? CC: @tqchen