```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(data: T.Buffer((1, 512, 7, 7), "float32"), kernel: T.Buffer((512, 512, 3, 3), "float32"), bias: T.Buffer((1, 512, 1, 1), "float32"), compute: T.Buffer((1, 512, 7, 7), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((1, 512, 9, 9))
        conv2d_nchw = T.alloc_buffer((1, 512, 7, 7))
        for i0, i1, i2, i3 in T.grid(1, 512, 9, 9):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(data[v_i0, v_i1, v_i2 - 1, v_i3 - 1])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i2 and v_i2 < 8 and 1 <= v_i3 and v_i3 < 8, data[v_i0, v_i1, v_i2 - 1, v_i3 - 1], T.float32(0))
        for nn, ff_0, yy, xx, rc_0, rc_1, ry, rx, ff_1 in T.grid(1, 256, 7, 7, 64, 8, 3, 3, 2):
            with T.block("conv2d_nchw"):
                v_nn = T.axis.spatial(1, nn)
                v_ff = T.axis.spatial(512, ff_0 * 2 + ff_1)
                v_yy, v_xx = T.axis.remap("SS", [yy, xx])
                v_rc = T.axis.reduce(512, rc_0 * 8 + rc_1)
                v_ry, v_rx = T.axis.remap("RR", [ry, rx])
                T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], kernel[v_ff, v_rc, v_ry, v_rx])
                T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
                with T.init():
                    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0)
                conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * kernel[v_ff, v_rc, v_ry, v_rx]
        for i0, i1_0, i2, i3, i1_1 in T.grid(1, 256, 7, 7, 2):
            with T.block("compute"):
                v_i0 = T.axis.spatial(1, i0)
                v_i1 = T.axis.spatial(512, i1_0 * 2 + i1_1)
                v_i2, v_i3 = T.axis.remap("SS", [i2, i3])
                T.reads(conv2d_nchw[v_i0, v_i1, v_i2, v_i3], bias[v_i0, v_i1, 0, 0])
                T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                compute[v_i0, v_i1, v_i2, v_i3] = T.max(conv2d_nchw[v_i0, v_i1, v_i2, v_i3] + bias[v_i0, v_i1, 0, 0], T.float32(0))
```
```python
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
    b0 = sch.get_block(name="pad_temp", func_name="main")
    l1, l2, l3, l4 = sch.get_loops(block=b0)
    b5 = sch.get_block(name="conv2d_nchw", func_name="main")
    l6, l7, l8, l9, l10, l11, l12 = sch.get_loops(block=b5)
    b13 = sch.get_block(name="T_add", func_name="main")
    l14, l15, l16, l17 = sch.get_loops(block=b13)
    b18 = sch.get_block(name="compute", func_name="main")
    l19, l20, l21, l22 = sch.get_loops(block=b18)
    b23 = sch.get_block(name="T_add", func_name="main")
    sch.compute_inline(block=b23)
    l24, l25 = sch.split(loop=l7, factors=[None, 2], preserve_unit_iters=True)
    l26, l27 = sch.split(loop=l10, factors=[None, 8], preserve_unit_iters=True)
    l28, l29, l30, l31, l32, l33, l34, l35, l36 = sch.get_loops(block=b5)
    sch.reorder(l28, l29, l31, l32, l33, l34, l35, l36, l30)
    l37, l38 = sch.split(loop=l20, factors=[None, 2], preserve_unit_iters=True)
    l39, l40, l41, l42, l43 = sch.get_loops(block=b18)
    sch.reorder(l39, l40, l42, l43, l41)
```
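For reference, the trace can be replayed like this (a minimal sketch; `OriginalModule` is a placeholder name for the unscheduled workload, which still contains the `T_add` block that the trace inlines):

```python
from tvm import tir

sch = tir.Schedule(OriginalModule)  # OriginalModule: the unscheduled conv2d + bias + relu module
apply_trace(sch)
sch.mod.show()  # prints the scheduled module shown above
```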
Then add the `compute_at` primitive:
```python
@I.ir_module
class Module:
    @T.prim_func
    def main(data: T.Buffer((1, 512, 7, 7), "float32"), kernel: T.Buffer((512, 512, 3, 3), "float32"), bias: T.Buffer((1, 512, 1, 1), "float32"), compute: T.Buffer((1, 512, 7, 7), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((1, 512, 9, 9))
        conv2d_nchw = T.alloc_buffer((1, 512, 7, 7))
        for i0, i1, i2, i3 in T.grid(1, 512, 9, 9):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(data[v_i0, v_i1, v_i2 - 1, v_i3 - 1])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i2 and v_i2 < 8 and 1 <= v_i3 and v_i3 < 8, data[v_i0, v_i1, v_i2 - 1, v_i3 - 1], T.float32(0))
        for i0, i1_0, i2, i3 in T.grid(1, 256, 7, 7):
            for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(1, 2, 1, 1, 512, 3, 3):
                with T.block("conv2d_nchw"):
                    v_nn = T.axis.spatial(1, ax0)
                    v_ff = T.axis.spatial(512, i1_0 * 2 + ax1)
                    v_yy = T.axis.spatial(7, i2 + ax2)
                    v_xx = T.axis.spatial(7, i3 + ax3)
                    v_rc, v_ry, v_rx = T.axis.remap("RRR", [ax4, ax5, ax6])
                    T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], kernel[v_ff, v_rc, v_ry, v_rx])
                    T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
                    with T.init():
                        conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0)
                    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * kernel[v_ff, v_rc, v_ry, v_rx]
            for i1_1 in range(2):
                with T.block("compute"):
                    v_i0 = T.axis.spatial(1, i0)
                    v_i1 = T.axis.spatial(512, i1_0 * 2 + i1_1)
                    v_i2, v_i3 = T.axis.remap("SS", [i2, i3])
                    T.reads(conv2d_nchw[v_i0, v_i1, v_i2, v_i3], bias[v_i0, v_i1, 0, 0])
                    T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                    compute[v_i0, v_i1, v_i2, v_i3] = T.max(conv2d_nchw[v_i0, v_i1, v_i2, v_i3] + bias[v_i0, v_i1, 0, 0], T.float32(0))
```
```python
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
    b0 = sch.get_block(name="pad_temp", func_name="main")
    l1, l2, l3, l4 = sch.get_loops(block=b0)
    b5 = sch.get_block(name="conv2d_nchw", func_name="main")
    l6, l7, l8, l9, l10, l11, l12 = sch.get_loops(block=b5)
    b13 = sch.get_block(name="T_add", func_name="main")
    l14, l15, l16, l17 = sch.get_loops(block=b13)
    b18 = sch.get_block(name="compute", func_name="main")
    l19, l20, l21, l22 = sch.get_loops(block=b18)
    b23 = sch.get_block(name="T_add", func_name="main")
    sch.compute_inline(block=b23)
    l24, l25 = sch.split(loop=l7, factors=[None, 2], preserve_unit_iters=True)
    l26, l27 = sch.split(loop=l10, factors=[None, 8], preserve_unit_iters=True)
    l28, l29, l30, l31, l32, l33, l34, l35, l36 = sch.get_loops(block=b5)
    sch.reorder(l28, l29, l31, l32, l33, l34, l35, l36, l30)
    l37, l38 = sch.split(loop=l20, factors=[None, 2], preserve_unit_iters=True)
    l39, l40, l41, l42, l43 = sch.get_loops(block=b18)
    sch.reorder(l39, l40, l42, l43, l41)
    sch.compute_at(block=b5, loop=l43, preserve_unit_loops=True, index=-1)
```
The loop splits are not preserved, which is different from te.schedule:

- In TE, we would get a (64, 8, 3, 3, 2) loop nest in the accumulate block (see the TE sketch below for what I mean).
- If I set `preserve_unit_loops`, the spatial axes follow the unit loops, which amounts to an implicit reorder: the inner nest becomes (1, 2, 1, 1, 512, 3, 3), where I would have expected (1, 1, 1, 1, 512, 3, 3, 2).

Is that what we want?
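For comparison, this is roughly the TE schedule I have in mind. It is an untested sketch; the topi calls and tensor names are illustrative stand-ins for the actual workload:

```python
import tvm
from tvm import te, topi

data = te.placeholder((1, 512, 7, 7), name="data")
kernel = te.placeholder((512, 512, 3, 3), name="kernel")
bias = te.placeholder((1, 512, 1, 1), name="bias")
conv = topi.nn.conv2d_nchw(data, kernel, stride=1, padding=1, dilation=1)
badd = topi.add(conv, bias)
out = topi.nn.relu(badd)

s = te.create_schedule(out.op)
s[badd].compute_inline()                 # analogous to compute_inline on T_add above
n, f, y, x = s[conv].op.axis
rc, ry, rx = s[conv].op.reduce_axis
fo, fi = s[conv].split(f, factor=2)
rco, rci = s[conv].split(rc, factor=8)
s[conv].reorder(n, fo, y, x, rco, rci, ry, rx, fi)
no, f2, y2, x2 = s[out].op.axis
f2o, f2i = s[out].split(f2, factor=2)
s[out].reorder(no, f2o, y2, x2, f2i)
s[conv].compute_at(s[out], x2)
# In TE, lowering keeps the (64, 8, 3, 3, 2) nest inside the attached stage:
print(tvm.lower(s, [data, kernel, bias, out], simple_mode=True))
```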
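And if the goal is just to recover that nest in TIR, one workaround I can think of is to re-split and reorder after `compute_at`. Again an untested sketch, continuing inside `apply_trace` above; the indices assume `get_loops` returns the 4 shared outer loops plus the 7 block-local `ax*` loops, and I have not verified the reorder is legal for a reduction block:

```python
# continuing inside apply_trace, after sch.compute_at(...):
loops = sch.get_loops(block=b5)  # [i0, i1_0, i2, i3, ax0, ax1(ff_1), ax2, ax3, ax4(rc), ax5(ry), ax6(rx)]
ff_1 = loops[5]                  # length-2 spatial axis that compute_at hoisted above the reductions
rc_o, rc_i = sch.split(loop=loops[8], factors=[None, 8], preserve_unit_iters=True)
sch.reorder(rc_o, rc_i, loops[9], loops[10], ff_1)  # aim for the (64, 8, 3, 3, 2) nest
```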