Redundant condition check

Hi all, I want to vectorize on local registers, but I found some redundant condition checks that make it impossible to vectorize:

The highlighted condition check is obviously unnecessary. Can anyone help me remove this condition check? Here is my code:

 import tvm
 from tvm import te

 # Module-level GPU thread-axis handles, reused by the schedule in vectorize().
 block_x = te.thread_axis("blockIdx.x")
 thread_x = te.thread_axis("threadIdx.x")

 def vectorize(M, target, dtype = 'float32'):
     """Build a padded elementwise add C = A + B for CUDA and print the kernel.

     Parameters
     ----------
     M : int
         Logical length of the input vectors.
     target : tvm.target.Target
         Currently unused (the build target is hard-coded to 'cuda'); kept
         for interface compatibility with existing callers.
     dtype : str
         Element type of the placeholders (default 'float32').
     """
     # Round M up to a multiple of 4 so the padded stages have a
     # vector-friendly extent.
     PM = (M + 4 - 1) // 4 * 4
     A = te.placeholder((M, ), name = "A", dtype = dtype)
     B = te.placeholder((M, ), name = "B", dtype = dtype)

     # Padded reads: out-of-range lanes yield 0. A single predicate does not
     # need the tvm.tir.all(...) wrapper, and the padding constant should
     # honour `dtype` rather than hard-coding float32.
     PA = te.compute((PM, ), lambda i : tvm.tir.if_then_else(i < M, A[i], tvm.tir.const(0.0, dtype)), name = "PA")
     PB = te.compute((PM, ), lambda i : tvm.tir.if_then_else(i < M, B[i], tvm.tir.const(0.0, dtype)), name = "PB")

     C = te.compute((M, ), lambda i : PA[i] + PB[i], name = "C")

     s = te.create_schedule(C.op)

     s[PA].compute_inline()
     s[PB].compute_inline()

     # Stage the computation through "local" (register) buffers.
     lC = s.cache_write(C, "local")
     lA = s.cache_read(PA, "local", [lC])
     lB = s.cache_read(PB, "local", [lC])

     m = C.op.axis[0]

     # 256 elements per block, 4 per thread.
     mo, mi = s[C].split(m, 256)
     mio, mii = s[C].split(mi, 4)

     s[C].bind(mo, block_x)
     s[C].bind(mio, thread_x)

     s[lC].compute_at(s[C], mio)

     s[lA].compute_at(s[C], mio)
     s[lB].compute_at(s[C], mio)

     func = tvm.build(s, [A, B, C], 'cuda')
     print(func.imported_modules[0].get_source())

 if __name__ == '__main__':
     # Driver: CUDA device with a C host, using a deliberately
     # non-multiple-of-4 length so the padding predicate is exercised.
     cuda_target = tvm.target.Target("cuda", host="c")
     vectorize(1025, cuda_target)

Hi~ here is a workaround if you’d like to switch to the TensorIR schedule. You can define one more stage to make the actual computation length aligned, and leverage reverse_compute_at to keep the aligned computation stage’s loop domain. You can print(s.mod) to show how the program is transformed immediately after each single schedule step.


def vectorize(M, target, dtype = 'float32'):
    """TensorIR workaround: pad the whole computation to the tile size so the
    inner 4-wide loop can be vectorized without a bounds check.

    The trick is an extra copy stage ``OutC`` that carries the original (M,)
    domain; ``reverse_compute_at`` keeps the tail guard in that stage while
    ``C`` runs over the aligned (PM,) domain and vectorizes cleanly.

    Parameters
    ----------
    M : int
        Logical vector length.
    target : str or tvm.target.Target
        Currently unused (the build call is commented out); kept for
        interface parity with the TE version.
    dtype : str
        Element type of the placeholders (default 'float32').
    """
    # Align to the full block tile: 256 = 64 threads * 4 lanes (see splits).
    PM = (M + 256 - 1) // 256 * 256
    A = te.placeholder((M, ), name = "A", dtype = dtype)
    B = te.placeholder((M, ), name = "B", dtype = dtype)
    # A single predicate needs no tvm.tir.all() wrapper; also honour `dtype`
    # for the padding constant instead of hard-coding float32.
    PA = te.compute((PM, ), lambda i : tvm.tir.if_then_else(i < M, A[i], tvm.tir.const(0.0, dtype)), name = "PA")
    PB = te.compute((PM, ), lambda i : tvm.tir.if_then_else(i < M, B[i], tvm.tir.const(0.0, dtype)), name = "PB")
    C = te.compute((PM, ), lambda i : PA[i] + PB[i], name = "C")
    # Copy stage restricted to the logical length; the tail guard lives here.
    OutC = te.compute((M, ), lambda i : C[i], name = "OutC")
    func = te.create_prim_func([A, B, OutC])

    s = tvm.tir.Schedule(func)  # replaces te.create_schedule(C.op)

    s.set_scope("PA", 0, "local")
    s.set_scope("PB", 0, "local")
    m = s.get_loops("C")[0]  # C.op.axis[0] in the TE version
    mo, mi = s.split(m, factors=[None, 256])
    mio, mii = s.split(mi, factors=[None, 4])
    s.bind(mo, "blockIdx.x")
    s.bind(mio, "threadIdx.x")
    # Keep OutC's guarded loop nest under C's thread loop.
    s.reverse_compute_at("OutC", mio)
    s.compute_at("PA", mio)
    s.compute_at("PB", mio)
    # Safe now: C's domain is aligned, so no predicate blocks vectorization.
    s.vectorize(mii)

    func = s.mod["main"]
    print(func.script())
    lowered_func = tvm.lower(func)
    print(lowered_func.script())
    # lib = tvm.build(func, target)
    # print(lib.imported_modules[0].get_source())

# Example run: 1025 is deliberately not a multiple of 256, so padding applies.
vectorize(1025, "cuda")

The result could be like:

# Printed TIR after scheduling (illustrative output of the script above).
@T.prim_func
def main(A: T.Buffer((1025,), "float32"), B: T.Buffer((1025,), "float32"), OutC: T.Buffer((1025,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        blockIdx_x = T.env_thread("blockIdx.x")
        T.launch_thread(blockIdx_x, 5)
        PA_local = T.allocate([4], "float32", "local")
        PB_local = T.allocate([4], "float32", "local")
        C = T.allocate([1280], "float32", "global")
        threadIdx_x = T.env_thread("threadIdx.x")
        T.launch_thread(threadIdx_x, 64)
        PA_local_1 = T.Buffer((4,), data=PA_local, scope="local")
        # Per-thread local loads keep the padding predicate (reads past 1025
        # produce 0).
        for ax0 in range(4):
            A_1 = T.Buffer((1025,), data=A.data)
            PA_local_1[ax0] = T.if_then_else(blockIdx_x * 256 + threadIdx_x * 4 + ax0 < 1025, A_1[blockIdx_x * 256 + threadIdx_x * 4 + ax0], T.float32(0))
        PB_local_1 = T.Buffer((4,), data=PB_local, scope="local")
        for ax0 in range(4):
            B_1 = T.Buffer((1025,), data=B.data)
            PB_local_1[ax0] = T.if_then_else(blockIdx_x * 256 + threadIdx_x * 4 + ax0 < 1025, B_1[blockIdx_x * 256 + threadIdx_x * 4 + ax0], T.float32(0))
        C_1 = T.Buffer((1280,), data=C)
        # The 4-wide add is a ramp/slice assignment with no bounds check —
        # this is the vectorized form the scheduling aimed for.
        C_1[blockIdx_x * 256 + threadIdx_x * 4:blockIdx_x * 256 + threadIdx_x * 4 + 4] = PA_local_1[0:4] + PB_local_1[0:4]
        # The tail guard survives only in the final copy-out to OutC.
        for ax0 in range(4):
            if blockIdx_x * 256 + threadIdx_x * 4 + ax0 < 1025:
                OutC_1 = T.Buffer((1025,), data=OutC.data)
                OutC_1[blockIdx_x * 256 + threadIdx_x * 4 + ax0] = C_1[blockIdx_x * 256 + threadIdx_x * 4 + ax0]

It works, thanks very much !!