How to optimize common load in tir?

Can we transform the following codes:

with T.attr(0, "compute_scope", "main_compute_"):
        for ax0, ax1, ax2, ax3_0 in T.grid(4, 2, 2, 2):
            cse_var_3: T.int32 = ax3_0 * 4
            cse_var_2: T.int32 = ax0 * 32 + ax1 * 16 + ax2 * 8 + cse_var_3
            cse_var_1: T.int32 = ax0 * 200 + ax1 * 80 + ax2 * 16 + cse_var_3
            pool_max_2 = T.Buffer((128,), data=pool_max)
            pool_max_2[cse_var_2:cse_var_2 + 4] = T.Broadcast(T.float32(-3.4028234663852886e+38), 4)
            A_2 = T.Buffer((800,), data=A)
            pool_max_2[cse_var_2:cse_var_2 + 4] = T.max(pool_max_2[cse_var_2:cse_var_2 + 4], A_2[cse_var_1:cse_var_1 + 4])
            pool_max_2[cse_var_2:cse_var_2 + 4] = T.max(pool_max_2[cse_var_2:cse_var_2 + 4], A_2[cse_var_1 + 8:cse_var_1 + 8 + 4])
            pool_max_2[cse_var_2:cse_var_2 + 4] = T.max(pool_max_2[cse_var_2:cse_var_2 + 4], A_2[cse_var_1 + 16:cse_var_1 + 16 + 4])
            pool_max_2[cse_var_2:cse_var_2 + 4] = T.max(pool_max_2[cse_var_2:cse_var_2 + 4], A_2[cse_var_1 + 40:cse_var_1 + 40 + 4])
            pool_max_2[cse_var_2:cse_var_2 + 4] = T.max(pool_max_2[cse_var_2:cse_var_2 + 4], A_2[cse_var_1 + 48:cse_var_1 + 48 + 4])
            pool_max_2[cse_var_2:cse_var_2 + 4] = T.max(pool_max_2[cse_var_2:cse_var_2 + 4], A_2[cse_var_1 + 56:cse_var_1 + 56 + 4])
            pool_max_2[cse_var_2:cse_var_2 + 4] = T.max(pool_max_2[cse_var_2:cse_var_2 + 4], A_2[cse_var_1 + 80:cse_var_1 + 80 + 4])
            pool_max_2[cse_var_2:cse_var_2 + 4] = T.max(pool_max_2[cse_var_2:cse_var_2 + 4], A_2[cse_var_1 + 88:cse_var_1 + 88 + 4])
            pool_max_2[cse_var_2:cse_var_2 + 4] = T.max(pool_max_2[cse_var_2:cse_var_2 + 4], A_2[cse_var_1 + 96:cse_var_1 + 96 + 4])

into:

with T.attr(0, "compute_scope", "main_compute_"):
  for ax0, ax1, ax2, ax3_0 in T.grid(4, 2, 2, 2):
    cse_var_3: T.int32 = ax3_0 * 4
    cse_var_2: T.int32 = ax0 * 32 + ax1 * 16 + ax2 * 8 + cse_var_3
    cse_var_1: T.int32 = ax0 * 200 + ax1 * 80 + ax2 * 16 + cse_var_3
    pool_max_2 = T.Buffer((128,), data=pool_max)
    pool_max_2[cse_var_2:cse_var_2 + 4] = T.Broadcast(T.float32(-3.4028234663852886e+38), 4)
    A_2 = T.Buffer((800,), data=A)
    tmp = pool_max_2[cse_var_2:cse_var_2 + 4]
    tmp = T.max(tmp, A_2[cse_var_1:cse_var_1 + 4])
    tmp = T.max(tmp, A_2[cse_var_1 + 8:cse_var_1 + 8 + 4])
    tmp = T.max(tmp, A_2[cse_var_1 + 16:cse_var_1 + 16 + 4])
    tmp = T.max(tmp, A_2[cse_var_1 + 40:cse_var_1 + 40 + 4])
    tmp = T.max(tmp, A_2[cse_var_1 + 48:cse_var_1 + 48 + 4])
    tmp = T.max(tmp, A_2[cse_var_1 + 56:cse_var_1 + 56 + 4])
    tmp = T.max(tmp, A_2[cse_var_1 + 80:cse_var_1 + 80 + 4])
    tmp = T.max(tmp, A_2[cse_var_1 + 88:cse_var_1 + 88 + 4])
    tmp = T.max(tmp, A_2[cse_var_1 + 96:cse_var_1 + 96 + 4])

This avoids duplicate loads for pool_max_2[cse_var_2:cse_var_2 + 4]

Setting pool_max_2 scope to “local” should be the same :slight_smile: