PadInput_i0, PadInput_i1, PadInput_i2, PadInput_i3 = tuple(PadInput.op.axis) + tuple(PadInput.op.reduce_axis)
T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)
pad_temp_i0, pad_temp_i1, pad_temp_i2, pad_temp_i3 = tuple(pad_temp.op.axis) + tuple(pad_temp.op.reduce_axis)
compute_i, compute_j = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)
T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)
compute_nn, compute_ff, compute_yy, compute_xx, compute_rc, compute_ry, compute_rx = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)
compute_red_ax0, compute_red_ax1, compute_red_ax2, compute_red_k1 = tuple(compute_red.op.axis) + tuple(compute_red.op.reduce_axis)
T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)
compute_local, = s.cache_write([compute], "local")
compute_local_nn_c, compute_local_ff_c, compute_local_yy_c, compute_local_xx_c, compute_local_rc, compute_local_ry, compute_local_rx = tuple(compute_local.op.axis) + tuple(compute_local.op.reduce_axis)
compute_local_nn_c_o_i, compute_local_nn_c_i = s[compute_local].split(compute_local_nn_c, factor=3)
compute_local_nn_c_o_o_i, compute_local_nn_c_o_i = s[compute_local].split(compute_local_nn_c_o_i, factor=1)
compute_local_nn_c_o_o_o_i, compute_local_nn_c_o_o_i = s[compute_local].split(compute_local_nn_c_o_o_i, factor=1)
compute_local_nn_c_o_o_o_o, compute_local_nn_c_o_o_o_i = s[compute_local].split(compute_local_nn_c_o_o_o_i, factor=1)
compute_local_ff_c_o_i, compute_local_ff_c_i = s[compute_local].split(compute_local_ff_c, factor=1)
compute_local_ff_c_o_o_i, compute_local_ff_c_o_i = s[compute_local].split(compute_local_ff_c_o_i, factor=1)
compute_local_ff_c_o_o_o_i, compute_local_ff_c_o_o_i = s[compute_local].split(compute_local_ff_c_o_o_i, factor=25)
compute_local_ff_c_o_o_o_o, compute_local_ff_c_o_o_o_i = s[compute_local].split(compute_local_ff_c_o_o_o_i, factor=1)
compute_local_yy_c_o_i, compute_local_yy_c_i = s[compute_local].split(compute_local_yy_c, factor=1)
compute_local_yy_c_o_o_i, compute_local_yy_c_o_i = s[compute_local].split(compute_local_yy_c_o_i, factor=1)
compute_local_yy_c_o_o_o_i, compute_local_yy_c_o_o_i = s[compute_local].split(compute_local_yy_c_o_o_i, factor=1)
compute_local_yy_c_o_o_o_o, compute_local_yy_c_o_o_o_i = s[compute_local].split(compute_local_yy_c_o_o_o_i, factor=5)
compute_local_xx_c_o_i, compute_local_xx_c_i = s[compute_local].split(compute_local_xx_c, factor=1)
compute_local_xx_c_o_o_i, compute_local_xx_c_o_i = s[compute_local].split(compute_local_xx_c_o_i, factor=1)
compute_local_xx_c_o_o_o_i, compute_local_xx_c_o_o_i = s[compute_local].split(compute_local_xx_c_o_o_i, factor=5)
compute_local_xx_c_o_o_o_o, compute_local_xx_c_o_o_o_i = s[compute_local].split(compute_local_xx_c_o_o_o_i, factor=1)
compute_local_rc_o_i, compute_local_rc_i = s[compute_local].split(compute_local_rc, factor=1)
compute_local_rc_o_o, compute_local_rc_o_i = s[compute_local].split(compute_local_rc_o_i, factor=1)
compute_local_ry_o_i, compute_local_ry_i = s[compute_local].split(compute_local_ry, factor=5)
compute_local_ry_o_o, compute_local_ry_o_i = s[compute_local].split(compute_local_ry_o_i, factor=1)
compute_local_rx_o_i, compute_local_rx_i = s[compute_local].split(compute_local_rx, factor=1)
compute_local_rx_o_o, compute_local_rx_o_i = s[compute_local].split(compute_local_rx_o_i, factor=1)
s[compute_local].reorder(compute_local_nn_c_o_o_o_o, compute_local_ff_c_o_o_o_o, compute_local_yy_c_o_o_o_o, compute_local_xx_c_o_o_o_o, compute_local_nn_c_o_o_o_i, compute_local_ff_c_o_o_o_i, compute_local_yy_c_o_o_o_i, compute_local_xx_c_o_o_o_i, compute_local_nn_c_o_o_i, compute_local_ff_c_o_o_i, compute_local_yy_c_o_o_i, compute_local_xx_c_o_o_i, compute_local_rc_o_o, compute_local_ry_o_o, compute_local_rx_o_o, compute_local_rc_o_i, compute_local_ry_o_i, compute_local_rx_o_i, compute_local_nn_c_o_i, compute_local_ff_c_o_i, compute_local_yy_c_o_i, compute_local_xx_c_o_i, compute_local_rc_i, compute_local_ry_i, compute_local_rx_i, compute_local_nn_c_i, compute_local_ff_c_i, compute_local_yy_c_i, compute_local_xx_c_i)
compute_nn_o_i, compute_nn_i = s[compute].split(compute_nn, factor=3)
compute_nn_o_o_i, compute_nn_o_i = s[compute].split(compute_nn_o_i, factor=1)
compute_nn_o_o_o, compute_nn_o_o_i = s[compute].split(compute_nn_o_o_i, factor=1)
compute_ff_o_i, compute_ff_i = s[compute].split(compute_ff, factor=1)
compute_ff_o_o_i, compute_ff_o_i = s[compute].split(compute_ff_o_i, factor=25)
compute_ff_o_o_o, compute_ff_o_o_i = s[compute].split(compute_ff_o_o_i, factor=1)
compute_yy_o_i, compute_yy_i = s[compute].split(compute_yy, factor=1)
compute_yy_o_o_i, compute_yy_o_i = s[compute].split(compute_yy_o_i, factor=1)
compute_yy_o_o_o, compute_yy_o_o_i = s[compute].split(compute_yy_o_o_i, factor=5)
compute_xx_o_i, compute_xx_i = s[compute].split(compute_xx, factor=1)
compute_xx_o_o_i, compute_xx_o_i = s[compute].split(compute_xx_o_i, factor=5)
compute_xx_o_o_o, compute_xx_o_o_i = s[compute].split(compute_xx_o_o_i, factor=1)
s[compute].reorder(compute_nn_o_o_o, compute_ff_o_o_o, compute_yy_o_o_o, compute_xx_o_o_o, compute_nn_o_o_i, compute_ff_o_o_i, compute_yy_o_o_i, compute_xx_o_o_i, compute_nn_o_i, compute_ff_o_i, compute_yy_o_i, compute_xx_o_i, compute_nn_i, compute_ff_i, compute_yy_i, compute_xx_i)
s[compute_local].compute_at(s[compute], compute_xx_o_i)
T_reshape_shared = s.cache_read(T_reshape, "shared", [compute_local])
T_reshape_shared_ax0, T_reshape_shared_ax1, T_reshape_shared_ax2, T_reshape_shared_ax3 = tuple(T_reshape_shared.op.axis)
s[T_reshape_shared].compute_at(s[compute_local], compute_local_rx_o_o)
s[T_reshape].compute_inline()
s[compute].compute_inline()
pad_temp_shared = s.cache_read(pad_temp, "shared", [compute_local])
pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3 = tuple(pad_temp_shared.op.axis)
s[pad_temp_shared].compute_at(s[compute_local], compute_local_rx_o_o)
s[pad_temp].compute_inline()
s[T_reshape].compute_inline()
s[PadInput].compute_inline()
T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused = s[T_reshape].fuse(T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3)
T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused_o, T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[T_reshape].split(T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused, factor=32)
s[T_reshape].bind(T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused_o, te.thread_axis("blockIdx.x"))
s[T_reshape].bind(T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused_i, te.thread_axis("threadIdx.x"))
compute_red_ax0_ax1_fused_ax2_fused = s[compute_red].fuse(compute_red_ax0, compute_red_ax1, compute_red_ax2)
compute_red_ax0_ax1_fused_ax2_fused_o, compute_red_ax0_ax1_fused_ax2_fused_i = s[compute_red].split(compute_red_ax0_ax1_fused_ax2_fused, factor=64)
s[compute_red].bind(compute_red_ax0_ax1_fused_ax2_fused_o, te.thread_axis("blockIdx.x"))
s[compute_red].bind(compute_red_ax0_ax1_fused_ax2_fused_i, te.thread_axis("threadIdx.x"))
compute_nn_o_o_o_ff_o_o_o_fused_yy_o_o_o_fused_xx_o_o_o_fused = s[compute].fuse(compute_nn_o_o_o, compute_ff_o_o_o, compute_yy_o_o_o, compute_xx_o_o_o)
s[compute].bind(compute_nn_o_o_o_ff_o_o_o_fused_yy_o_o_o_fused_xx_o_o_o_fused, te.thread_axis("blockIdx.x"))
compute_nn_o_o_i_ff_o_o_i_fused_yy_o_o_i_fused_xx_o_o_i_fused = s[compute].fuse(compute_nn_o_o_i, compute_ff_o_o_i, compute_yy_o_o_i, compute_xx_o_o_i)
s[compute].bind(compute_nn_o_o_i_ff_o_o_i_fused_yy_o_o_i_fused_xx_o_o_i_fused, te.thread_axis("vthread"))
compute_nn_o_i_ff_o_i_fused_yy_o_i_fused_xx_o_i_fused = s[compute].fuse(compute_nn_o_i, compute_ff_o_i, compute_yy_o_i, compute_xx_o_i)
s[compute].bind(compute_nn_o_i_ff_o_i_fused_yy_o_i_fused_xx_o_i_fused, te.thread_axis("threadIdx.x"))
T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[T_reshape_shared].fuse(T_reshape_shared_ax0, T_reshape_shared_ax1, T_reshape_shared_ax2, T_reshape_shared_ax3)
T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[T_reshape_shared].split(T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=1)
s[T_reshape_shared].vectorize(T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[T_reshape_shared].split(T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=125)
s[T_reshape_shared].bind(T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[pad_temp_shared].fuse(pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3)
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=1)
s[pad_temp_shared].vectorize(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=125)
s[pad_temp_shared].bind(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
s[compute_local].pragma(compute_local_nn_c_o_o_o_o, "auto_unroll_max_step", 512)
s[compute_local].pragma(compute_local_nn_c_o_o_o_o, "unroll_explicit", True)
s[compute_red].pragma(compute_red_ax0_ax1_fused_ax2_fused_o, "auto_unroll_max_step", 64)
s[compute_red].pragma(compute_red_ax0_ax1_fused_ax2_fused_o, "unroll_explicit", True)