[TensorIR] if_then_else simplification when using compute_inline

Hi Everyone,

I am trying to inline a loop nest (with an if-then-else condition) using the compute_inline scheduling primitive. Here is the example prim_func:

# Two-stage example: block "A" writes the intermediate buffer B from A,
# and block "B" reads B to produce C. Both stages guard their computation
# with the same predicate (v_axis1 < 14) and store 0 otherwise.
@T.prim_func(private=True)
def nested_if_example(A: T.Buffer((2, 14), "float32"), C: T.Buffer((2, 16), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    # Intermediate buffer; it has 16 columns while A has only 14.
    B = T.alloc_buffer((2, 16), "float32")
    for axis0, axis1 in T.grid(2, 16):
        with T.block("A"):
            v_axis0, v_axis1 = T.axis.remap("SS", [axis0, axis1])
            T.reads(A[v_axis0, v_axis1])
            T.writes(B[v_axis0, v_axis1])
            # Producer: A * 2 where v_axis1 < 14, else 0.
            B[v_axis0, v_axis1] = T.if_then_else(v_axis1 < 14, A[v_axis0, v_axis1] * T.float32(2), T.float32(0))

    for axis0, axis1 in T.grid(2, 16):
        with T.block("B"):
            v_axis0, v_axis1 = T.axis.remap("SS", [axis0, axis1])
            T.reads(B[v_axis0, v_axis1])
            T.writes(C[v_axis0, v_axis1])
            # Consumer: B * B under the same guard as the producer, else 0.
            C[v_axis0, v_axis1] = T.if_then_else(v_axis1 < 14, B[v_axis0, v_axis1] * B[v_axis0, v_axis1], T.float32(0))

After applying compute_inline, the code looks like below,

# Build a schedule for nested_if_example and apply compute_inline to the
# first child block of "root" (the producer block "A").
sch = tvm.tir.Schedule(nested_if_example)
blocks = sch.get_child_blocks(sch.get_block("root"))
sch.compute_inline(blocks[0])

# Result reported after compute_inline: only block "B" remains, and each
# read of the intermediate buffer is replaced by the producer expression.
@T.prim_func(private=True)
def after_compute_inline(A: T.Buffer((2, 14), "float32"), C: T.Buffer((2, 16), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for axis0, axis1 in T.grid(2, 16):
        with T.block("B"):
            v_axis0, v_axis1 = T.axis.remap("SS", [axis0, axis1])
            T.reads(A[v_axis0, v_axis1])
            T.writes(C[v_axis0, v_axis1])
            # The outer guard (v_axis1 < 14) is repeated inside both operands
            # of the multiplication; this is the redundancy being asked about.
            C[v_axis0, v_axis1] = T.if_then_else(v_axis1 < 14, T.if_then_else(v_axis1 < 14, A[v_axis0, v_axis1] * T.float32(2), T.float32(0)) * T.if_then_else(v_axis1 < 14, A[v_axis0, v_axis1] * T.float32(2), T.float32(0)), T.float32(0))

Even after applying reverse_compute_inline, the nested if_then_else expressions still appear:

# Build a fresh schedule for nested_if_example and apply
# reverse_compute_inline to the second child block of "root" (the consumer
# block "B").
sch = tvm.tir.Schedule(nested_if_example)
blocks = sch.get_child_blocks(sch.get_block("root"))
sch.reverse_compute_inline(blocks[1])

# Result reported after reverse_compute_inline: only block "A" remains and it
# now writes C directly. The stored expression is the same nested form as in
# the compute_inline result.
@T.prim_func(private=True)
def after_reverse_compute_inline(A: T.Buffer((2, 14), "float32"), C: T.Buffer((2, 16), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for axis0, axis1 in T.grid(2, 16):
        with T.block("A"):
            v_axis0, v_axis1 = T.axis.remap("SS", [axis0, axis1])
            T.reads(A[v_axis0, v_axis1])
            T.writes(C[v_axis0, v_axis1])
            # The outer guard (v_axis1 < 14) is still repeated inside both
            # operands of the multiplication.
            C[v_axis0, v_axis1] = T.if_then_else(v_axis1 < 14, T.if_then_else(v_axis1 < 14, A[v_axis0, v_axis1] * T.float32(2), T.float32(0)) * T.if_then_else(v_axis1 < 14, A[v_axis0, v_axis1] * T.float32(2), T.float32(0)), T.float32(0))

The expected code should look like this,

# Desired form: a single if_then_else on v_axis1 < 14, with the inner guards
# removed and the product of the two producer expressions in the true branch.
@T.prim_func(private=True)
def expected(A: T.Buffer((2, 14), "float32"), C: T.Buffer((2, 16), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for axis0, axis1 in T.grid(2, 16):
        with T.block("A"):
            v_axis0, v_axis1 = T.axis.remap("SS", [axis0, axis1])
            T.reads(A[v_axis0, v_axis1])
            T.writes(C[v_axis0, v_axis1])
            # Only one guard remains; the false branch is still 0.
            C[v_axis0, v_axis1] = T.if_then_else(v_axis1 < 14, A[v_axis0, v_axis1] * T.float32(2) * A[v_axis0, v_axis1] * T.float32(2), T.float32(0))

Question: What is the right way to transform nested_if_example to the expected prim_func?

Thanks in advance,

Rahul

CC: @Hzfengsy @tqchen

It’s not a problem with the schedule primitive: the result is still correct, just not concise. Could you please check whether the arith simplifier can simplify it?

Hi @Hzfengsy

Thanks for the suggestion. After running the Simplify pass, the nested if-then-else expressions are no longer generated.

I have another, similar test case (module1) where the Simplify pass is unable to simplify the nested if-then-else expression.

# Second test case: both blocks use the guard
# (v_axis1 == 1 and 6 <= v_axis4) and put the constant 0 in the TRUE branch,
# with the real computation in the FALSE branch.
@I.ir_module
class module1:
    @T.prim_func(private=True)
    def main(A: T.Buffer((T.int64(1), T.int64(2), T.int64(8)), "int32"), C: T.Buffer((T.int64(1), T.int64(2), T.int64(8)), "uint8")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        B = T.alloc_buffer((T.int64(1), T.int64(2), T.int64(8)), "int32")
        for axis0, axis1, axis4 in T.grid(T.int64(1), T.int64(2), T.int64(8)):
            with T.block("P"):
                v_axis0, v_axis1, v_axis4 = T.axis.remap("SSS", [axis0, axis1, axis4])
                T.reads(A[v_axis0, v_axis1, v_axis4])
                T.writes(B[v_axis0, v_axis1, v_axis4])
                # Producer: 0 where the guard holds, otherwise a copy of A.
                B[v_axis0, v_axis1, v_axis4] = T.if_then_else(v_axis1 == T.int64(1) and T.int64(6) <= v_axis4, T.int32(0), A[v_axis0, v_axis1, v_axis4])
        for axis0, axis1, axis4 in T.grid(T.int64(1), T.int64(2), T.int64(8)):
            with T.block("Q"):
                v_axis0, v_axis1, v_axis4 = T.axis.remap("SSS", [axis0, axis1, axis4])
                T.reads(B[v_axis0, v_axis1, v_axis4])
                T.writes(C[v_axis0, v_axis1, v_axis4])
                # Consumer: 0 where the guard holds, otherwise B clamped to
                # [0, 255] and cast to uint8.
                C[v_axis0, v_axis1, v_axis4] = T.if_then_else(v_axis1 == T.int64(1) and T.int64(6) <= v_axis4, T.uint8(0), T.Cast("uint8", T.max(0, T.min(B[v_axis0, v_axis1, v_axis4], 255))))

On the other hand, after swapping the two branches and changing the conditions of the if-then-else accordingly, running the tvm.tir.transform.Simplify() pass no longer leaves nested if-then-else expressions. Below is the example, module2, after flipping the conditions.

# Variant of module1 with the branches swapped: the real computation is in
# the TRUE branch and the constant 0 in the FALSE branch, under the guard
# (v_axis1 == 0 and 6 > v_axis4).
# NOTE(review): the logical negation of module1's guard
# (v_axis1 == 1 and 6 <= v_axis4) is (v_axis1 != 1 or v_axis4 < 6). The
# `and` guard used here is narrower (e.g. at v_axis1 == 0, v_axis4 == 7
# module1 keeps A while this module stores 0), so this is not an exact
# branch-flip of module1 -- confirm this is intended.
@I.ir_module
class module2:
    @T.prim_func(private=True)
    def main(A: T.Buffer((T.int64(1), T.int64(2), T.int64(8)), "int32"), C: T.Buffer((T.int64(1), T.int64(2), T.int64(8)), "uint8")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        B = T.alloc_buffer((T.int64(1), T.int64(2), T.int64(8)), "int32")
        for axis0, axis1, axis4 in T.grid(T.int64(1), T.int64(2), T.int64(8)):
            with T.block("P"):
                v_axis0, v_axis1, v_axis4 = T.axis.remap("SSS", [axis0, axis1, axis4])
                T.reads(A[v_axis0, v_axis1, v_axis4])
                T.writes(B[v_axis0, v_axis1, v_axis4])
                # Producer: a copy of A where the guard holds, otherwise 0.
                B[v_axis0, v_axis1, v_axis4] = T.if_then_else(v_axis1 == T.int64(0) and T.int64(6) > v_axis4, A[v_axis0, v_axis1, v_axis4], T.int32(0))
        for axis0, axis1, axis4 in T.grid(T.int64(1), T.int64(2), T.int64(8)):
            with T.block("Q"):
                v_axis0, v_axis1, v_axis4 = T.axis.remap("SSS", [axis0, axis1, axis4])
                T.reads(B[v_axis0, v_axis1, v_axis4])
                T.writes(C[v_axis0, v_axis1, v_axis4])
                # Consumer: B clamped to [0, 255] and cast to uint8 where the
                # guard holds, otherwise 0.
                C[v_axis0, v_axis1, v_axis4] = T.if_then_else(v_axis1 == T.int64(0) and T.int64(6) > v_axis4, T.Cast("uint8", T.max(0, T.min(B[v_axis0, v_axis1, v_axis4], 255))), T.uint8(0))

Inlined code before flipping:

# Inlined form reported for module1 (guard with 0 in the TRUE branch): the
# producer's if_then_else is still present inside T.min, under the same
# guard as the outer if_then_else.
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(A: T.Buffer((T.int64(1), T.int64(2), T.int64(8)), "int32"), C: T.Buffer((T.int64(1), T.int64(2), T.int64(8)), "uint8")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for axis0, axis1, axis4 in T.grid(T.int64(1), T.int64(2), T.int64(8)):
            with T.block("P"):
                v_axis0, v_axis1, v_axis4 = T.axis.remap("SSS", [axis0, axis1, axis4])
                T.reads(A[v_axis0, v_axis1, v_axis4])
                T.writes(C[v_axis0, v_axis1, v_axis4])
                # The inner if_then_else sits in the FALSE branch of the outer
                # one, yet it tests the identical condition.
                C[v_axis0, v_axis1, v_axis4] = T.if_then_else(v_axis1 == T.int64(1) and T.int64(6) <= v_axis4, T.uint8(0), T.Cast("uint8", T.max(0, T.min(T.if_then_else(v_axis1 == T.int64(1) and T.int64(6) <= v_axis4, 0, A[v_axis0, v_axis1, v_axis4]), 255))))

Inlined code after flipping conditions:

# Inlined form reported for module2 (computation in the TRUE branch): a
# single if_then_else remains, and the unit-extent axis0 has been replaced
# by the constant index 0.
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(A: T.Buffer((T.int64(1), T.int64(2), T.int64(8)), "int32"), C: T.Buffer((T.int64(1), T.int64(2), T.int64(8)), "uint8")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for axis0, axis1, axis4 in T.grid(T.int64(1), T.int64(2), T.int64(8)):
            with T.block("P"):
                v_axis0 = T.axis.spatial(T.int64(1), T.int64(0))
                v_axis1, v_axis4 = T.axis.remap("SS", [axis1, axis4])
                T.reads(A[T.int64(0), v_axis1, v_axis4])
                T.writes(C[T.int64(0), v_axis1, v_axis4])
                # No nested if_then_else: A is read directly inside the clamp,
                # and the guard is printed as v_axis4 < 6 instead of 6 > v_axis4.
                C[T.int64(0), v_axis1, v_axis4] = T.if_then_else(v_axis1 == T.int64(0) and v_axis4 < T.int64(6), T.Cast("uint8", T.max(0, T.min(A[T.int64(0), v_axis1, v_axis4], 255))), T.uint8(0))

Questions:

  1. Is there a specific reason why the Simplify pass is unable to simplify the expression when the condition uses (<=)?
  2. Is there a separate pass (other than Simplify) in TVM that flips if-then-else conditions?

Thanks in advance,

Rahul