Hi Everyone,
I am trying to inline a loop nest (containing an if-then-else condition) using the compute_inline scheduling primitive. Here is the example prim_func:
# Two-stage example: block "A" produces the intermediate buffer B from A,
# and block "B" consumes B to produce C. Both stages guard their computation
# with the same predicate (v_axis1 < 14) and write 0 otherwise.
@T.prim_func(private=True)
def nested_if_example(A: T.Buffer((2, 14), "float32"), C: T.Buffer((2, 16), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    B = T.alloc_buffer((2, 16), "float32")
    # Producer: B = A * 2 where v_axis1 < 14, else 0.
    for axis0, axis1 in T.grid(2, 16):
        with T.block("A"):
            v_axis0, v_axis1 = T.axis.remap("SS", [axis0, axis1])
            T.reads(A[v_axis0, v_axis1])
            T.writes(B[v_axis0, v_axis1])
            B[v_axis0, v_axis1] = T.if_then_else(
                v_axis1 < 14,
                A[v_axis0, v_axis1] * T.float32(2),
                T.float32(0),
            )
    # Consumer: C = B * B where v_axis1 < 14, else 0.
    for axis0, axis1 in T.grid(2, 16):
        with T.block("B"):
            v_axis0, v_axis1 = T.axis.remap("SS", [axis0, axis1])
            T.reads(B[v_axis0, v_axis1])
            T.writes(C[v_axis0, v_axis1])
            C[v_axis0, v_axis1] = T.if_then_else(
                v_axis1 < 14,
                B[v_axis0, v_axis1] * B[v_axis0, v_axis1],
                T.float32(0),
            )
After applying compute_inline, the code looks like this:
# Schedule script: inline the first child block of "root" (the producer
# block "A" of nested_if_example) into its consumer.
sch = tvm.tir.Schedule(nested_if_example)
blocks = sch.get_child_blocks(sch.get_block("root"))
sch.compute_inline(blocks[0])


# Resulting function after compute_inline. Each use of B in block "B" has been
# replaced by the producer's whole if_then_else expression, so the same
# predicate (v_axis1 < 14) now appears both outside and inside.
@T.prim_func(private=True)
def after_compute_inline(A: T.Buffer((2, 14), "float32"), C: T.Buffer((2, 16), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for axis0, axis1 in T.grid(2, 16):
        with T.block("B"):
            v_axis0, v_axis1 = T.axis.remap("SS", [axis0, axis1])
            T.reads(A[v_axis0, v_axis1])
            T.writes(C[v_axis0, v_axis1])
            C[v_axis0, v_axis1] = T.if_then_else(
                v_axis1 < 14,
                T.if_then_else(v_axis1 < 14, A[v_axis0, v_axis1] * T.float32(2), T.float32(0))
                * T.if_then_else(v_axis1 < 14, A[v_axis0, v_axis1] * T.float32(2), T.float32(0)),
                T.float32(0),
            )
Even after applying reverse_compute_inline instead, the nested if_then_else expressions still appear:
# Schedule script: inline the second child block of "root" (the consumer
# block "B" of nested_if_example) back into its producer.
sch = tvm.tir.Schedule(nested_if_example)
blocks = sch.get_child_blocks(sch.get_block("root"))
sch.reverse_compute_inline(blocks[1])


# Resulting function after reverse_compute_inline. The surviving block is "A",
# but the stored expression is the same nested if_then_else as produced by
# compute_inline.
@T.prim_func(private=True)
def after_reverse_compute_inline(A: T.Buffer((2, 14), "float32"), C: T.Buffer((2, 16), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for axis0, axis1 in T.grid(2, 16):
        with T.block("A"):
            v_axis0, v_axis1 = T.axis.remap("SS", [axis0, axis1])
            T.reads(A[v_axis0, v_axis1])
            T.writes(C[v_axis0, v_axis1])
            C[v_axis0, v_axis1] = T.if_then_else(
                v_axis1 < 14,
                T.if_then_else(v_axis1 < 14, A[v_axis0, v_axis1] * T.float32(2), T.float32(0))
                * T.if_then_else(v_axis1 < 14, A[v_axis0, v_axis1] * T.float32(2), T.float32(0)),
                T.float32(0),
            )
The expected code should look like this,
# Desired result: a single if_then_else on v_axis1 < 14, with the inner guards
# (which test the same predicate) removed from the "then" branch.
@T.prim_func(private=True)
def expected(A: T.Buffer((2, 14), "float32"), C: T.Buffer((2, 16), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for axis0, axis1 in T.grid(2, 16):
        with T.block("A"):
            v_axis0, v_axis1 = T.axis.remap("SS", [axis0, axis1])
            T.reads(A[v_axis0, v_axis1])
            T.writes(C[v_axis0, v_axis1])
            C[v_axis0, v_axis1] = T.if_then_else(
                v_axis1 < 14,
                A[v_axis0, v_axis1] * T.float32(2) * A[v_axis0, v_axis1] * T.float32(2),
                T.float32(0),
            )
Question: What is the right way to transform nested_if_example to the expected prim_func?
Thanks in advance,
Rahul