[TIR] TIR schedule compute_at breaks the block's split steps; it's different from the TE schedule. Is that what we expect?

# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(data: T.Buffer((1, 512, 7, 7), "float32"), kernel: T.Buffer((512, 512, 3, 3), "float32"), bias: T.Buffer((1, 512, 1, 1), "float32"), compute: T.Buffer((1, 512, 7, 7), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((1, 512, 9, 9))
        conv2d_nchw = T.alloc_buffer((1, 512, 7, 7))
        for i0, i1, i2, i3 in T.grid(1, 512, 9, 9):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(data[v_i0, v_i1, v_i2 - 1, v_i3 - 1])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i2 and v_i2 < 8 and 1 <= v_i3 and v_i3 < 8, data[v_i0, v_i1, v_i2 - 1, v_i3 - 1], T.float32(0))
        for nn, ff_0, yy, xx, rc_0, rc_1, ry, rx, ff_1 in T.grid(1, 256, 7, 7, 64, 8, 3, 3, 2):
            with T.block("conv2d_nchw"):
                v_nn = T.axis.spatial(1, nn)
                v_ff = T.axis.spatial(512, ff_0 * 2 + ff_1)
                v_yy, v_xx = T.axis.remap("SS", [yy, xx])
                v_rc = T.axis.reduce(512, rc_0 * 8 + rc_1)
                v_ry, v_rx = T.axis.remap("RR", [ry, rx])
                T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], kernel[v_ff, v_rc, v_ry, v_rx])
                T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
                with T.init():
                    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0)
                conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * kernel[v_ff, v_rc, v_ry, v_rx]
        for i0, i1_0, i2, i3, i1_1 in T.grid(1, 256, 7, 7, 2):
            with T.block("compute"):
                v_i0 = T.axis.spatial(1, i0)
                v_i1 = T.axis.spatial(512, i1_0 * 2 + i1_1)
                v_i2, v_i3 = T.axis.remap("SS", [i2, i3])
                T.reads(conv2d_nchw[v_i0, v_i1, v_i2, v_i3], bias[v_i0, v_i1, 0, 0])
                T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                compute[v_i0, v_i1, v_i2, v_i3] = T.max(conv2d_nchw[v_i0, v_i1, v_i2, v_i3] + bias[v_i0, v_i1, 0, 0], T.float32(0))

# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="pad_temp", func_name="main")
  l1, l2, l3, l4 = sch.get_loops(block=b0)
  b5 = sch.get_block(name="conv2d_nchw", func_name="main")
  l6, l7, l8, l9, l10, l11, l12 = sch.get_loops(block=b5)
  b13 = sch.get_block(name="T_add", func_name="main")
  l14, l15, l16, l17 = sch.get_loops(block=b13)
  b18 = sch.get_block(name="compute", func_name="main")
  l19, l20, l21, l22 = sch.get_loops(block=b18)
  b23 = sch.get_block(name="T_add", func_name="main")
  sch.compute_inline(block=b23)
  l24, l25 = sch.split(loop=l7, factors=[None, 2], preserve_unit_iters=True)
  l26, l27 = sch.split(loop=l10, factors=[None, 8], preserve_unit_iters=True)
  l28, l29, l30, l31, l32, l33, l34, l35, l36 = sch.get_loops(block=b5)
  sch.reorder(l28, l29, l31, l32, l33, l34, l35, l36, l30)
  l37, l38 = sch.split(loop=l20, factors=[None, 2], preserve_unit_iters=True)
  l39, l40, l41, l42, l43 = sch.get_loops(block=b18)
  sch.reorder(l39, l40, l42, l43, l41)

and after adding the compute_at primitive:

@I.ir_module
class Module:
    @T.prim_func
    def main(data: T.Buffer((1, 512, 7, 7), "float32"), kernel: T.Buffer((512, 512, 3, 3), "float32"), bias: T.Buffer((1, 512, 1, 1), "float32"), compute: T.Buffer((1, 512, 7, 7), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((1, 512, 9, 9))
        conv2d_nchw = T.alloc_buffer((1, 512, 7, 7))
        for i0, i1, i2, i3 in T.grid(1, 512, 9, 9):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(data[v_i0, v_i1, v_i2 - 1, v_i3 - 1])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i2 and v_i2 < 8 and 1 <= v_i3 and v_i3 < 8, data[v_i0, v_i1, v_i2 - 1, v_i3 - 1], T.float32(0))
        for i0, i1_0, i2, i3 in T.grid(1, 256, 7, 7):
            for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(1, 2, 1, 1, 512, 3, 3):
                with T.block("conv2d_nchw"):
                    v_nn = T.axis.spatial(1, ax0)
                    v_ff = T.axis.spatial(512, i1_0 * 2 + ax1)
                    v_yy = T.axis.spatial(7, i2 + ax2)
                    v_xx = T.axis.spatial(7, i3 + ax3)
                    v_rc, v_ry, v_rx = T.axis.remap("RRR", [ax4, ax5, ax6])
                    T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], kernel[v_ff, v_rc, v_ry, v_rx])
                    T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
                    with T.init():
                        conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0)
                    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * kernel[v_ff, v_rc, v_ry, v_rx]
            for i1_1 in range(2):
                with T.block("compute"):
                    v_i0 = T.axis.spatial(1, i0)
                    v_i1 = T.axis.spatial(512, i1_0 * 2 + i1_1)
                    v_i2, v_i3 = T.axis.remap("SS", [i2, i3])
                    T.reads(conv2d_nchw[v_i0, v_i1, v_i2, v_i3], bias[v_i0, v_i1, 0, 0])
                    T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                    compute[v_i0, v_i1, v_i2, v_i3] = T.max(conv2d_nchw[v_i0, v_i1, v_i2, v_i3] + bias[v_i0, v_i1, 0, 0], T.float32(0))

# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="pad_temp", func_name="main")
  l1, l2, l3, l4 = sch.get_loops(block=b0)
  b5 = sch.get_block(name="conv2d_nchw", func_name="main")
  l6, l7, l8, l9, l10, l11, l12 = sch.get_loops(block=b5)
  b13 = sch.get_block(name="T_add", func_name="main")
  l14, l15, l16, l17 = sch.get_loops(block=b13)
  b18 = sch.get_block(name="compute", func_name="main")
  l19, l20, l21, l22 = sch.get_loops(block=b18)
  b23 = sch.get_block(name="T_add", func_name="main")
  sch.compute_inline(block=b23)
  l24, l25 = sch.split(loop=l7, factors=[None, 2], preserve_unit_iters=True)
  l26, l27 = sch.split(loop=l10, factors=[None, 8], preserve_unit_iters=True)
  l28, l29, l30, l31, l32, l33, l34, l35, l36 = sch.get_loops(block=b5)
  sch.reorder(l28, l29, l31, l32, l33, l34, l35, l36, l30)
  l37, l38 = sch.split(loop=l20, factors=[None, 2], preserve_unit_iters=True)
  l39, l40, l41, l42, l43 = sch.get_loops(block=b18)
  sch.reorder(l39, l40, l42, l43, l41)
  sch.compute_at(block=b5, loop=l43, preserve_unit_loops=True, index=-1)

The loop splits are not preserved, which differs from te.schedule.

  1. In TE, we would get (64, 8, 3, 3, 2) in the accumulate block (see the TE sketch after this message).
  2. If I set preserve_unit_loops, the spatial axis is grouped with the unit loops, which means there is an implicit reorder: shouldn't (1, 2, 1, 1, 512, 3, 3) be (1, 1, 1, 1, 512, 3, 3, 2)?

Is that what we want?
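For reference, here is a minimal TE sketch of point 1 (my own reconstruction with topi, not the original workload): the splits and reorder are issued before compute_at, yet the (64, 8, 3, 3, 2) structure on the conv stage survives, because TE only materializes the loops when tvm.lower runs.

import tvm
from tvm import te, topi

data = te.placeholder((1, 512, 7, 7), name="data")
kernel = te.placeholder((512, 512, 3, 3), name="kernel")
bias = te.placeholder((1, 512, 1, 1), name="bias")
conv = topi.nn.conv2d_nchw(data, kernel, 1, 1, 1)
biased = topi.add(conv, bias)
out = topi.nn.relu(biased)

s = te.create_schedule(out.op)
s[biased].compute_inline()

nn, ff, yy, xx = s[conv].op.axis
rc, ry, rx = s[conv].op.reduce_axis
ff_o, ff_i = s[conv].split(ff, factor=2)
rc_o, rc_i = s[conv].split(rc, factor=8)
s[conv].reorder(nn, ff_o, yy, xx, rc_o, rc_i, ry, rx, ff_i)

n, c, h, w = s[out].op.axis
c_o, c_i = s[out].split(c, factor=2)
s[out].reorder(n, c_o, h, w, c_i)

# compute_at is recorded after the splits, but nothing is lowered yet,
# so the split/reorder on conv is still honored inside the attach scope.
s[conv].compute_at(s[out], w)
print(tvm.lower(s, [data, kernel, bias, out], simple_mode=True))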

It is expected. Please do compute_at before splitting.
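For instance, a sketch of the trace in the suggested order (loop positions are illustrative and assume the same starting PrimFunc the trace above is applied to): tile the consumer, attach the producer, and only then split/reorder the loops that compute_at generated.

# from tvm import tir
def apply_trace_fixed(sch: tir.Schedule) -> None:
  b_conv = sch.get_block(name="conv2d_nchw", func_name="main")
  b_add = sch.get_block(name="T_add", func_name="main")
  b_out = sch.get_block(name="compute", func_name="main")
  sch.compute_inline(block=b_add)
  # Tile the consumer first and attach the producer under its innermost
  # spatial loop; the ff extent of 2 is then derived from the attach region,
  # so no explicit ff split on the producer is needed.
  i0, i1, i2, i3 = sch.get_loops(block=b_out)
  sch.split(loop=i1, factors=[None, 2], preserve_unit_iters=True)
  i0, i1_0, i1_1, i2, i3 = sch.get_loops(block=b_out)
  sch.reorder(i0, i1_0, i2, i3, i1_1)
  sch.compute_at(block=b_conv, loop=i3, preserve_unit_loops=True, index=-1)
  # Only now split/reorder the producer loops that compute_at generated:
  # [i0, i1_0, i2, i3, ax0, ff_i, ax2, ax3, rc, ry, rx]
  rc = sch.get_loops(block=b_conv)[8]
  sch.split(loop=rc, factors=[None, 8], preserve_unit_iters=True)
  ax0, ff_i, ax2, ax3, rc_0, rc_1, ry, rx = sch.get_loops(block=b_conv)[4:]
  # push ff_i innermost to recover the (64, 8, 3, 3, 2) accumulate nest
  sch.reorder(ax0, ax2, ax3, rc_0, rc_1, ry, rx, ff_i)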

Thanks for your reply! BTW, why can't we just keep the splits? Is there any case that can't be handled?

TE schedule uses lazy lowering, so it can remember the split operation wherever it is called. TIR schedule, however, runs in eager mode and applies each transformation immediately.

For the specific case of split + compute_at, the split size depends on the compute_at region, so the split should be applied after the compute_at.
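A quick way to see the eager behaviour (illustrative; `mod` stands for the un-scheduled PrimFunc that the trace above starts from): every primitive rewrites sch.mod immediately, so the split is already materialized in the IR before compute_at runs, and compute_at then re-derives the producer loop nest from the attach region.

from tvm import tir

sch = tir.Schedule(mod)
sch.compute_inline(sch.get_block("T_add", func_name="main"))
b_conv = sch.get_block("conv2d_nchw", func_name="main")
b_out = sch.get_block("compute", func_name="main")

ff = sch.get_loops(b_conv)[1]
sch.split(ff, factors=[None, 2])
sch.mod.show()  # the two loops from the ff split are already in the IR here

sch.compute_at(b_conv, sch.get_loops(b_out)[-1], preserve_unit_loops=True)
sch.mod.show()  # the producer loops were rebuilt from the region needed at
                # the attach point, so the earlier split has been folded away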

Got it, thx a lot :grinning: