Print auto_scheduler Python schedule with topi ops

I tried to write a complicated new op and tune it with auto_scheduler. In the compute definition of the op, I use topi.reshape several times, and I use topi.nn.conv2d as well.

The new op works well with auto_scheduler, but when I look at the printed schedule, I am confused.

    PadInput_i0, PadInput_i1, PadInput_i2, PadInput_i3 = tuple(PadInput.op.axis) + tuple(PadInput.op.reduce_axis)
    T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)
    pad_temp_i0, pad_temp_i1, pad_temp_i2, pad_temp_i3 = tuple(pad_temp.op.axis) + tuple(pad_temp.op.reduce_axis)
    compute_i, compute_j = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)
    T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)
    compute_nn, compute_ff, compute_yy, compute_xx, compute_rc, compute_ry, compute_rx = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)
    compute_red_ax0, compute_red_ax1, compute_red_ax2, compute_red_k1 = tuple(compute_red.op.axis) + tuple(compute_red.op.reduce_axis)
    T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)
    compute_local, = s.cache_write([compute], "local")
    compute_local_nn_c, compute_local_ff_c, compute_local_yy_c, compute_local_xx_c, compute_local_rc, compute_local_ry, compute_local_rx = tuple(compute_local.op.axis) + tuple(compute_local.op.reduce_axis)
    compute_local_nn_c_o_i, compute_local_nn_c_i = s[compute_local].split(compute_local_nn_c, factor=3)

You can see there are several T_reshape ops, so I cannot distinguish them. So I want to know: can we name a topi op, or is there another way to distinguish the ops? Thank you.

This unfortunately cannot be controlled by the auto-scheduler. One good feature of the auto-scheduler is that it schedules an entire fused compute graph together. On the other hand, it implies that we don’t differentiate ops once they are fused into a single TE compute. Specifically, the names of the TE tensors, such as PadInput, T_reshape, and compute, are defined in TOPI.
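
For example (a minimal sketch just to illustrate the point, not taken from your workload), every call to topi.reshape produces a tensor whose op carries TOPI's default name T_reshape, which is exactly what the printer shows:

    from tvm import te, topi

    x = te.placeholder((2, 8), name="x")
    y1 = topi.reshape(x, (4, 4))
    y2 = topi.reshape(x, (16,))
    # Both ops get TOPI's default name, so the printer cannot tell them apart.
    print(y1.op.name, y2.op.name)  # T_reshape T_reshape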

However, for your particular case, T_reshape seems to come from the same op:

T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = \
  tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)

T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = \
  tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)

T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = \
  tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)

So IIUC, this is just the case where the auto-scheduler printer wants to make sure the variable it refers to (e.g., T_reshape_ax0) is the right one, so it re-generates the axis tuples from the same reshape op.

Thank you! Your explanation about TE tensors is very clear.

However, I think those T_reshape.op may refer to different ops, because I use several topi.reshape calls. Here’s my code:

from tvm import te, topi


def function():
    A = te.placeholder((1, 3, 5, 5), name="A", dtype="float32")
    kernel = te.placeholder((5, 5), name="kernel", dtype="float32")
    max_val = 1e4
    se_h, se_w = kernel.shape
    origin = [se_h // 2, se_w // 2]
    pad_e1 = [0, 0, origin[0], origin[1]]
    pad_e2 = [0, 0, se_h - origin[0] - 1, se_w - origin[1] - 1]
    border_value = max_val
    output = topi.nn.pad(A, pad_e1, pad_e2, pad_value=border_value)
    print(output.shape)
    neighborhood = te.compute((5, 5), lambda i0, i1: te.if_then_else(kernel[i0, i1] == 0, -max_val, 0), name="neighborhood")
    B, C, H, W = A.shape
    Hpad, Wpad = output.shape[-2:]
    reshape_kernel = neight2channels(kernel)
    reshape1 = topi.reshape(output, [B*C, 1, Hpad, Wpad])
    conv1 = topi.nn.conv2d(reshape1, reshape_kernel, 1, 0, 1)
    out1 = topi.min(conv1, 1)
    reshape2 = topi.reshape(neighborhood, [1])
    out2 = topi.subtract(out1, reshape2)  # note: out2 is not used below, so it does not show up in the DAG
    out = topi.reshape(out1, [B, C, H, W])
    return [A, kernel, out]


def neight2channels(kernel):
    h, w = kernel.shape
    temp = te.compute((h*w, h*w), lambda i, j: te.if_then_else(i == j, 1, 0), name="temp")
    reshape_kernel = topi.reshape(temp, [h*w, 1, h, w])
    return reshape_kernel
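
For context, I tune it with the usual auto_scheduler flow, roughly like this (the target, trial count, and log file name here are just placeholders, not my exact settings); print_best is what produces the schedule below:

    import tvm
    from tvm import auto_scheduler

    @auto_scheduler.register_workload
    def workload():
        # Wrap the TE definition above so it can be registered by name.
        return function()

    target = tvm.target.Target("cuda")
    task = auto_scheduler.SearchTask(func=workload, args=(), target=target)
    log_file = "morph.json"
    task.tune(auto_scheduler.TuningOptions(
        num_measure_trials=200,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    ))
    # print_best emits the Python schedule pasted below.
    print(task.print_best(log_file))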

And here is the printed schedule:

    PadInput_i0, PadInput_i1, PadInput_i2, PadInput_i3 = tuple(PadInput.op.axis) + tuple(PadInput.op.reduce_axis)
    T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)
    pad_temp_i0, pad_temp_i1, pad_temp_i2, pad_temp_i3 = tuple(pad_temp.op.axis) + tuple(pad_temp.op.reduce_axis)
    compute_i, compute_j = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)
    T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)
    compute_nn, compute_ff, compute_yy, compute_xx, compute_rc, compute_ry, compute_rx = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)
    compute_red_ax0, compute_red_ax1, compute_red_ax2, compute_red_k1 = tuple(compute_red.op.axis) + tuple(compute_red.op.reduce_axis)
    T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)
    compute_local, = s.cache_write([compute], "local")
    compute_local_nn_c, compute_local_ff_c, compute_local_yy_c, compute_local_xx_c, compute_local_rc, compute_local_ry, compute_local_rx = tuple(compute_local.op.axis) + tuple(compute_local.op.reduce_axis)
    compute_local_nn_c_o_i, compute_local_nn_c_i = s[compute_local].split(compute_local_nn_c, factor=3)
    compute_local_nn_c_o_o_i, compute_local_nn_c_o_i = s[compute_local].split(compute_local_nn_c_o_i, factor=1)
    compute_local_nn_c_o_o_o_i, compute_local_nn_c_o_o_i = s[compute_local].split(compute_local_nn_c_o_o_i, factor=1)
    compute_local_nn_c_o_o_o_o, compute_local_nn_c_o_o_o_i = s[compute_local].split(compute_local_nn_c_o_o_o_i, factor=1)
    compute_local_ff_c_o_i, compute_local_ff_c_i = s[compute_local].split(compute_local_ff_c, factor=1)
    compute_local_ff_c_o_o_i, compute_local_ff_c_o_i = s[compute_local].split(compute_local_ff_c_o_i, factor=1)
    compute_local_ff_c_o_o_o_i, compute_local_ff_c_o_o_i = s[compute_local].split(compute_local_ff_c_o_o_i, factor=25)
    compute_local_ff_c_o_o_o_o, compute_local_ff_c_o_o_o_i = s[compute_local].split(compute_local_ff_c_o_o_o_i, factor=1)
    compute_local_yy_c_o_i, compute_local_yy_c_i = s[compute_local].split(compute_local_yy_c, factor=1)
    compute_local_yy_c_o_o_i, compute_local_yy_c_o_i = s[compute_local].split(compute_local_yy_c_o_i, factor=1)
    compute_local_yy_c_o_o_o_i, compute_local_yy_c_o_o_i = s[compute_local].split(compute_local_yy_c_o_o_i, factor=1)
    compute_local_yy_c_o_o_o_o, compute_local_yy_c_o_o_o_i = s[compute_local].split(compute_local_yy_c_o_o_o_i, factor=5)
    compute_local_xx_c_o_i, compute_local_xx_c_i = s[compute_local].split(compute_local_xx_c, factor=1)
    compute_local_xx_c_o_o_i, compute_local_xx_c_o_i = s[compute_local].split(compute_local_xx_c_o_i, factor=1)
    compute_local_xx_c_o_o_o_i, compute_local_xx_c_o_o_i = s[compute_local].split(compute_local_xx_c_o_o_i, factor=5)
    compute_local_xx_c_o_o_o_o, compute_local_xx_c_o_o_o_i = s[compute_local].split(compute_local_xx_c_o_o_o_i, factor=1)
    compute_local_rc_o_i, compute_local_rc_i = s[compute_local].split(compute_local_rc, factor=1)
    compute_local_rc_o_o, compute_local_rc_o_i = s[compute_local].split(compute_local_rc_o_i, factor=1)
    compute_local_ry_o_i, compute_local_ry_i = s[compute_local].split(compute_local_ry, factor=5)
    compute_local_ry_o_o, compute_local_ry_o_i = s[compute_local].split(compute_local_ry_o_i, factor=1)
    compute_local_rx_o_i, compute_local_rx_i = s[compute_local].split(compute_local_rx, factor=1)
    compute_local_rx_o_o, compute_local_rx_o_i = s[compute_local].split(compute_local_rx_o_i, factor=1)
    s[compute_local].reorder(compute_local_nn_c_o_o_o_o, compute_local_ff_c_o_o_o_o, compute_local_yy_c_o_o_o_o, compute_local_xx_c_o_o_o_o, compute_local_nn_c_o_o_o_i, compute_local_ff_c_o_o_o_i, compute_local_yy_c_o_o_o_i, compute_local_xx_c_o_o_o_i, compute_local_nn_c_o_o_i, compute_local_ff_c_o_o_i, compute_local_yy_c_o_o_i, compute_local_xx_c_o_o_i, compute_local_rc_o_o, compute_local_ry_o_o, compute_local_rx_o_o, compute_local_rc_o_i, compute_local_ry_o_i, compute_local_rx_o_i, compute_local_nn_c_o_i, compute_local_ff_c_o_i, compute_local_yy_c_o_i, compute_local_xx_c_o_i, compute_local_rc_i, compute_local_ry_i, compute_local_rx_i, compute_local_nn_c_i, compute_local_ff_c_i, compute_local_yy_c_i, compute_local_xx_c_i)
    compute_nn_o_i, compute_nn_i = s[compute].split(compute_nn, factor=3)
    compute_nn_o_o_i, compute_nn_o_i = s[compute].split(compute_nn_o_i, factor=1)
    compute_nn_o_o_o, compute_nn_o_o_i = s[compute].split(compute_nn_o_o_i, factor=1)
    compute_ff_o_i, compute_ff_i = s[compute].split(compute_ff, factor=1)
    compute_ff_o_o_i, compute_ff_o_i = s[compute].split(compute_ff_o_i, factor=25)
    compute_ff_o_o_o, compute_ff_o_o_i = s[compute].split(compute_ff_o_o_i, factor=1)
    compute_yy_o_i, compute_yy_i = s[compute].split(compute_yy, factor=1)
    compute_yy_o_o_i, compute_yy_o_i = s[compute].split(compute_yy_o_i, factor=1)
    compute_yy_o_o_o, compute_yy_o_o_i = s[compute].split(compute_yy_o_o_i, factor=5)
    compute_xx_o_i, compute_xx_i = s[compute].split(compute_xx, factor=1)
    compute_xx_o_o_i, compute_xx_o_i = s[compute].split(compute_xx_o_i, factor=5)
    compute_xx_o_o_o, compute_xx_o_o_i = s[compute].split(compute_xx_o_o_i, factor=1)
    s[compute].reorder(compute_nn_o_o_o, compute_ff_o_o_o, compute_yy_o_o_o, compute_xx_o_o_o, compute_nn_o_o_i, compute_ff_o_o_i, compute_yy_o_o_i, compute_xx_o_o_i, compute_nn_o_i, compute_ff_o_i, compute_yy_o_i, compute_xx_o_i, compute_nn_i, compute_ff_i, compute_yy_i, compute_xx_i)
    s[compute_local].compute_at(s[compute], compute_xx_o_i)
    T_reshape_shared = s.cache_read(T_reshape, "shared", [compute_local])
    T_reshape_shared_ax0, T_reshape_shared_ax1, T_reshape_shared_ax2, T_reshape_shared_ax3 = tuple(T_reshape_shared.op.axis)
    s[T_reshape_shared].compute_at(s[compute_local], compute_local_rx_o_o)
    s[T_reshape].compute_inline()
    s[compute].compute_inline()
    pad_temp_shared = s.cache_read(pad_temp, "shared", [compute_local])
    pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3 = tuple(pad_temp_shared.op.axis)
    s[pad_temp_shared].compute_at(s[compute_local], compute_local_rx_o_o)
    s[pad_temp].compute_inline()
    s[T_reshape].compute_inline()
    s[PadInput].compute_inline()
    T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused = s[T_reshape].fuse(T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3)
    T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused_o, T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[T_reshape].split(T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused, factor=32)
    s[T_reshape].bind(T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused_o, te.thread_axis("blockIdx.x"))
    s[T_reshape].bind(T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused_i, te.thread_axis("threadIdx.x"))
    compute_red_ax0_ax1_fused_ax2_fused = s[compute_red].fuse(compute_red_ax0, compute_red_ax1, compute_red_ax2)
    compute_red_ax0_ax1_fused_ax2_fused_o, compute_red_ax0_ax1_fused_ax2_fused_i = s[compute_red].split(compute_red_ax0_ax1_fused_ax2_fused, factor=64)
    s[compute_red].bind(compute_red_ax0_ax1_fused_ax2_fused_o, te.thread_axis("blockIdx.x"))
    s[compute_red].bind(compute_red_ax0_ax1_fused_ax2_fused_i, te.thread_axis("threadIdx.x"))
    compute_nn_o_o_o_ff_o_o_o_fused_yy_o_o_o_fused_xx_o_o_o_fused = s[compute].fuse(compute_nn_o_o_o, compute_ff_o_o_o, compute_yy_o_o_o, compute_xx_o_o_o)
    s[compute].bind(compute_nn_o_o_o_ff_o_o_o_fused_yy_o_o_o_fused_xx_o_o_o_fused, te.thread_axis("blockIdx.x"))
    compute_nn_o_o_i_ff_o_o_i_fused_yy_o_o_i_fused_xx_o_o_i_fused = s[compute].fuse(compute_nn_o_o_i, compute_ff_o_o_i, compute_yy_o_o_i, compute_xx_o_o_i)
    s[compute].bind(compute_nn_o_o_i_ff_o_o_i_fused_yy_o_o_i_fused_xx_o_o_i_fused, te.thread_axis("vthread"))
    compute_nn_o_i_ff_o_i_fused_yy_o_i_fused_xx_o_i_fused = s[compute].fuse(compute_nn_o_i, compute_ff_o_i, compute_yy_o_i, compute_xx_o_i)
    s[compute].bind(compute_nn_o_i_ff_o_i_fused_yy_o_i_fused_xx_o_i_fused, te.thread_axis("threadIdx.x"))
    T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[T_reshape_shared].fuse(T_reshape_shared_ax0, T_reshape_shared_ax1, T_reshape_shared_ax2, T_reshape_shared_ax3)
    T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[T_reshape_shared].split(T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=1)
    s[T_reshape_shared].vectorize(T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
    T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[T_reshape_shared].split(T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=125)
    s[T_reshape_shared].bind(T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
    pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[pad_temp_shared].fuse(pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3)
    pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=1)
    s[pad_temp_shared].vectorize(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
    pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=125)
    s[pad_temp_shared].bind(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
    s[compute_local].pragma(compute_local_nn_c_o_o_o_o, "auto_unroll_max_step", 512)
    s[compute_local].pragma(compute_local_nn_c_o_o_o_o, "unroll_explicit", True)
    s[compute_red].pragma(compute_red_ax0_ax1_fused_ax2_fused_o, "auto_unroll_max_step", 64)
    s[compute_red].pragma(compute_red_ax0_ax1_fused_ax2_fused_o, "unroll_explicit", True)

So, given the fusion, how can I reuse the printed schedule?

Ah, you’re right. In this case the auto-scheduler doesn’t differentiate the reshape ops, because their names are defined in the TOPI compute (and are all the same…). I had this issue before, but that was for cache_read, which is added by the auto-scheduler itself, so I could let the auto-scheduler control its name.

We may need to improve the printer here (https://github.com/apache/tvm/blob/main/src/auto_scheduler/compute_dag.cc#L1206) by adding renaming logic. I’ll try to take a look if I have time this week.

I just realized that this problem is even more complicated. If we simply rename the second T_reshape to T_reshape_1, then you still cannot match the TE compute unless we rename the corresponding op in the compute DAG as well. However, printing the compute DAG and printing the schedule are different functions, and it would be unsafe to maintain such a mapping. While the meta schedule that @junrushao et al. are working on has resolved this issue, I would say the best way for now is probably just to match them manually, as we know that a line of the form ... = tuple(name.op.axis) + tuple(name.op.reduce_axis) is always the starting point of scheduling a stage.
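
To make the manual matching easier, one small trick (my own sketch, not an auto-scheduler feature) is to dump every op in the compute DAG together with the names of its producers; the three T_reshape ops then become distinguishable by what they read from:

    from tvm import auto_scheduler

    # "function" is the TE definition posted above.
    dag = auto_scheduler.ComputeDAG(function())
    for op in dag.ops:
        producers = [t.op.name for t in op.input_tensors]
        print(op.name, "<-", producers)
    # The three T_reshape ops show different producers, e.g.:
    #   T_reshape <- ['PadInput'], T_reshape <- ['temp'], T_reshape <- ['compute_red']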


Yes, I ran into the same issue: since the auto_scheduler does the fusion, the DAG ops don’t correspond to those we wrote in TE and TOPI. But the printed schedule uses the DAG ops, so we can’t simply match or reuse the printed schedule.

However, when I look into the DAG, I find it doesn’t change much from the TE. So if I write a new TE definition according to the DAG, the schedule should work for it, is that right?

And here’s the DAG:

> A = PLACEHOLDER [1, 3, 5, 5]
> PadInput(i0, i1, i2, i3) = tir.if_then_else(((((i2 >= 2) && (i2 < 7)) && (i3 >= 2)) && (i3 < 7)), A[i0, i1, (i2 - 2), (i3 - 2)], 10000f)
> T_reshape(ax0, ax1, ax2, ax3) = PadInput[0, floormod(floordiv(floordiv((((((ax0 + ax1)*9) + ax2)*9) + ax3), 9), 9), 3), floormod(floordiv((((((ax0 + ax1)*9) + ax2)*9) + ax3), 9), 9), floormod((((((ax0 + ax1)*9) + ax2)*9) + ax3), 9)]
> pad_temp(i0, i1, i2, i3) = T_reshape[i0, i1, i2, i3]
> temp(i, j) = tir.if_then_else((i == j), 1, 0)
> T_reshape(ax0, ax1, ax2, ax3) = temp[floormod(floordiv((((((ax0 + ax1)*5) + ax2)*5) + ax3), 25), 25), floormod((((((ax0 + ax1)*5) + ax2)*5) + ax3), 25)]
> compute(nn, ff, yy, xx) += (pad_temp[nn, rc, (yy + ry), (xx + rx)]*float32(T_reshape[ff, rc, ry, rx]))
> compute_red(ax0, ax1, ax2) min= compute[ax0, k1, ax1, ax2]
> T_reshape(ax0, ax1, ax2, ax3) = compute_red[floormod(floordiv(floordiv(((((((ax0*3) + ax1)*5) + ax2)*5) + ax3), 5), 5), 3), floormod(floordiv(((((((ax0*3) + ax1)*5) + ax2)*5) + ax3), 5), 5), floormod(((((((ax0*3) + ax1)*5) + ax2)*5) + ax3), 5)]

Yes, I actually do this very often, although it’s much simpler for me because I don’t have multiple identically-named ops.
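
For example, the first two lines of your DAG can be transcribed back into plain TE roughly like this (shapes and constants are read off your DAG print; keeping the name PadInput keeps it aligned with the printed schedule):

    import tvm
    from tvm import te

    A = te.placeholder((1, 3, 5, 5), name="A", dtype="float32")
    # PadInput(i0, i1, i2, i3): pad H and W by 2 on each side with 10000, as in the DAG.
    PadInput = te.compute(
        (1, 3, 9, 9),
        lambda i0, i1, i2, i3: te.if_then_else(
            tvm.tir.all(i2 >= 2, i2 < 7, i3 >= 2, i3 < 7),
            A[i0, i1, i2 - 2, i3 - 2],
            tvm.tir.const(1e4, "float32"),
        ),
        name="PadInput",
    )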

OK, thank you so much. :grinning: BTW, do you know whether te.schedule contains the compute information or just the schedule? I mean the sch in sch, args = task.apply_best(log_file). I tried to change the args and got an error.

Perhaps adding renaming logic there may not work, because it would only affect the printed ... = tuple(name.op.axis) + tuple(name.op.reduce_axis) lines, while the following schedule steps would still print the same names. Maybe we should change the op->name instead?

Like I mentioned, the op name is defined in the TOPI compute, so the auto-scheduler cannot change it. To solve this issue, we need to introduce a general utility to rename ops in a TE compute.
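
Until such a utility exists, one workaround I can think of (my own sketch, not an existing TVM API) is to write the reshape directly as a te.compute so that you choose the op name yourself. The indexing below mirrors the row-major flatten/unflatten that topi.reshape performs, so only the names should change:

    import tvm
    from tvm import te

    def reshape_named(t, newshape, name):
        # Row-major reshape written directly as a te.compute so we control the op name
        # (my own sketch; it computes the same index mapping as topi.reshape).
        def to_flat(indices, shape):
            flat = indices[0]
            for i, s in zip(indices[1:], shape[1:]):
                flat = flat * s + i
            return flat

        def from_flat(flat, shape):
            coords = []
            for s in reversed(shape):
                coords.append(tvm.tir.indexmod(flat, s))
                flat = tvm.tir.indexdiv(flat, s)
            return list(reversed(coords))

        return te.compute(
            newshape,
            lambda *idx: t(*from_flat(to_flat(list(idx), newshape), list(t.shape))),
            name=name,
        )

    # In function() above, e.g.:
    #   reshape1 = reshape_named(output, [B * C, 1, Hpad, Wpad], "reshape_pad")
    #   out = reshape_named(out1, [B, C, H, W], "reshape_out")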