I am a TensorIR/TVMScript beginner. I ran the experiment below and it triggered an error. How should I deal with it? Thanks.
# TensorIR function with two stages:
#   * block "B" fills the intermediate buffer B through an opaque extern call,
#   * block "C" computes C = B + 1.0 element-wise.
# A and C are bound to the handle parameters `a` and `c`; B is allocated
# inside the function. All three buffers are 128x128 float32.
@T.prim_func
def compute_at_call_extern(a: T.handle, c: T.handle) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    # Stage 1: one block instance per value of i, with a single spatial axis vi.
    for i in range(128):
        with T.block("B"):
            vi = T.axis.spatial(128, i)
            # The block body only touches memory through raw pointers, so the
            # accessed regions are declared by hand: all of A is read and
            # B[vi, 0:128] is written.
            # NOTE(review): the write region is a block var in dim 0 and a
            # full range in dim 1; presumably this is the region that
            # compute_at tries to pattern-match -- confirm against the TVM
            # version in use.
            T.reads([A[0:128,0:128]])
            T.writes([B[vi, 0:128]])
            # Extern call test_cust_mul(ptr into A, ptr into B).
            # NOTE(review): the tvm_access_ptr arguments appear to follow
            # (dtype annotation, data var, offset, extent, rw_mask) -- confirm
            # with the TVM builtin docs. Here A uses offset 0 and B uses
            # offset vi*128; both pass extent 0 and mask 1.
            T.evaluate(T.call_extern("test_cust_mul",
                T.tvm_access_ptr(T.type_annotation(dtype="float32"),A.data, 0, 0, 1, dtype="handle"),
                T.tvm_access_ptr(T.type_annotation(dtype="float32"),B.data, vi*128, 0, 1, dtype="handle"), dtype="handle"))
    # Stage 2: element-wise consumer of B; both axes are spatial ("SS").
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
def test_compute_at2():
    """Reproduce the ``compute_at`` failure for the extern-call producer.

    Builds a schedule for ``compute_at_call_extern`` with every debug check
    enabled, prints the initial script, then tries to move block "B" under
    the outer loop of block "C". The final print shows the transformed
    module if the primitive succeeds.
    """
    schedule = tir.Schedule(compute_at_call_extern, debug_mask="all")
    print(schedule.mod.script())

    producer = schedule.get_block("B")
    consumer = schedule.get_block("C")
    # Only the outer loop of the consumer is used as the compute_at target.
    outer_loop, _inner_loop = schedule.get_loops(consumer)

    schedule.compute_at(producer, outer_loop, preserve_unit_loops=True)  # failed here
    print(schedule.mod.script())
Check failed: (false) is false: ValueError: BufferRegion pattern match failed
This is the loop structure I was hoping for:
// Intended result: the extern call sits inside the same outer i-loop as the
// inner j-loop that reads B[i][j] to produce C[i][j].
for(i=0; i<128; i++) {
    // Called with A and a pointer offset i*128 into B; presumably this
    // fills row i of B before the j-loop below reads it.
    test_cust_mul(A, B+i*128);
    for(j=0; j<128; j++) {
        C[i][j] = B[i][j] + 1.0;
    }
}
Thank you for your attention and reply!