[AutoScheduler] How to enable optimization passes for AutoScheduler?

Hi, I’m using AutoScheduler to tune my kernels, but the tuned results still leave room for further optimization. For example, when a tensor goes through two consecutive topi.reshape calls, the generated code does not eliminate the first one. Is there an option to enable IR optimization passes before, during, or after tuning?
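For context, the kind of rewrite being asked about can be illustrated on a toy expression IR. This is a minimal sketch only: it is not an AutoScheduler option or a TVM API, the names `Input`, `Reshape`, and `fold_reshapes` are hypothetical, and it assumes reshapes are pure row-major view changes (as with topi.reshape), so only the final target shape of a reshape chain matters.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Input:
    """A graph input with a static shape (hypothetical toy IR, not TVM's)."""
    name: str
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class Reshape:
    """A reshape node; `data` is the producer expression."""
    data: "Expr"
    shape: Tuple[int, ...]


Expr = Union[Input, Reshape]


def fold_reshapes(e: Expr) -> Expr:
    """Peephole rewrite: reshape(reshape(x, s1), s2) -> reshape(x, s2),
    and drop the reshape entirely when s2 equals x's own shape."""
    if isinstance(e, Input):
        return e
    inner = fold_reshapes(e.data)
    # Row-major reshapes compose, so the intermediate shape can be skipped.
    if isinstance(inner, Reshape):
        inner = inner.data
    if inner.shape == e.shape:
        return inner  # reshaping to the same shape is a no-op
    return Reshape(inner, e.shape)


x = Input("X", (32, 3, 32, 32))
z = Reshape(Reshape(x, (32, 1, 3, 32, 32)), (32, 3, 32, 32))
print(fold_reshapes(z))  # Input(name='X', shape=(32, 3, 32, 32))
```

Because `fold_reshapes` recurses first, the producer it sees is already folded, so one level of collapsing per node is enough.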


This is not expected. The second reshape should be inlined by this rule https://github.com/apache/tvm/blob/837d5d4e5273ffb73dcae21c64762143a7b560c1/src/auto_scheduler/search_policy/sketch_policy_rules.h#L101-L105

Could you share the compute_dag of your task?
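To make "inlined" concrete, here is a pure-Python sketch that models each TE stage as a function from an output index to a value. It is not AutoScheduler's implementation; the helpers `ravel`, `unravel`, `reshape`, and `materialize` are hypothetical, and the shape is shrunk for illustration. Without inlining, the intermediate reshape is written to its own buffer; with inlining, the consumer evaluates the producer's index expression directly and no intermediate buffer exists.

```python
from itertools import product
from typing import Callable, Dict, Tuple

Index = Tuple[int, ...]
Tensor = Callable[[Index], float]


def ravel(idx: Index, shape: Index) -> int:
    """Row-major multi-index -> flat offset."""
    flat = 0
    for i, dim in zip(idx, shape):
        flat = flat * dim + i
    return flat


def unravel(flat: int, shape: Index) -> Index:
    """Flat offset -> row-major multi-index."""
    idx = []
    for dim in reversed(shape):
        idx.append(flat % dim)
        flat //= dim
    return tuple(reversed(idx))


def reshape(src: Tensor, src_shape: Index, dst_shape: Index) -> Tensor:
    # Like a TE compute: dst[idx] = src[unravel(ravel(idx, dst_shape), src_shape)]
    return lambda idx: src(unravel(ravel(idx, dst_shape), src_shape))


def materialize(t: Tensor, shape: Index) -> Dict[Index, float]:
    # A stage computed at root: every element is stored to a buffer.
    return {idx: t(idx) for idx in product(*(range(d) for d in shape))}


N, C, H, W = 2, 3, 4, 5  # small shape for illustration
x_shape = (N, C, H, W)
y_shape = (N, 1, C, H, W)


def X(idx: Index) -> float:
    return float(ravel(idx, x_shape))  # synthetic input data


# Without inlining: Y is stored to y_buf, and Z reads from that buffer.
y_buf = materialize(reshape(X, x_shape, y_shape), y_shape)
z_two_stage = materialize(reshape(lambda i: y_buf[i], y_shape, x_shape), x_shape)

# With inlining: Z evaluates Y's expression directly; no intermediate buffer.
z_inlined = materialize(reshape(reshape(X, x_shape, y_shape), y_shape, x_shape), x_shape)

print(z_two_stage == z_inlined)  # True
print(len(y_buf))  # 120 intermediate stores that inlining avoids
```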

Hi, thanks for your reply! The Python code is as follows:

```python
import tvm
from tvm import auto_scheduler, te, topi


@auto_scheduler.register_workload
def double_reshape(n: int, c: int, h: int, w: int):
    X = te.placeholder((n, c, h, w))
    # First reshape inserts a unit axis: (n, c, h, w) -> (n, 1, c, h, w)
    Y = topi.reshape(X, (n, 1, c, h, w))
    # Second reshape restores the original shape, so Z is element-wise identical to X
    Z = topi.reshape(Y, (n, c, h, w))
    return [X, Z]


if __name__ == '__main__':
    target = tvm.target.Target(target='llvm', host='llvm')
    task = auto_scheduler.SearchTask(func=double_reshape, args=(32, 3, 32, 32), target=target)
    # Print the compute DAG that the search policy will schedule
    print(task.compute_dag)
    tune_options = auto_scheduler.TuningOptions(
        num_measure_trials=256,
        num_measures_per_round=64,
        verbose=2,
    )
    task.tune(tune_options)
```

The printed compute_dag is:

```
placeholder = PLACEHOLDER [32, 3, 32, 32]
T_reshape(ax0, ax1, ax2, ax3, ax4) = placeholder[floormod(floordiv(floordiv(floordiv((((((((ax0 + ax1)*3) + ax2)*32) + ax3)*32)  ..(OMITTED).. iv((((((((ax0 + ax1)*3) + ax2)*32) + ax3)*32) + ax4), 32), 32), floormod((((((((ax0 + ax1)*3) + ax2)*32) + ax3)*32) + ax4), 32)]
T_reshape(ax0, ax1, ax2, ax3) = T_reshape[floormod(floordiv(floordiv(floordiv(((((((ax0*3) + ax1)*32) + ax2)*32) + ax3), 32), 32 ..(OMITTED)..  floormod(floordiv(((((((ax0*3) + ax1)*32) + ax2)*32) + ax3), 32), 32), floormod(((((((ax0*3) + ax1)*32) + ax2)*32) + ax3), 32)]
```

All of the searched schedules still contain both reshapes in the generated code; the first reshape is never eliminated.
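To see why the first reshape is redundant, the two index expressions can be transcribed to Python and composed. This is a sketch under stated assumptions: the `..(OMITTED)..` parts of the printed expressions are reconstructed from topi.reshape's row-major semantics rather than copied from the output, and `first_reshape_src` and `second_reshape_src` are hypothetical names. Python's `//` and `%` use floor semantics, matching TVM's `floordiv` and `floormod`. The check only shows the composed index map is the identity; whether TVM's arithmetic simplifier actually reduces the inlined expression to a direct read is a separate question this does not answer.

```python
from itertools import product

N, C, H, W = 32, 3, 32, 32  # shape from the workload above


def first_reshape_src(ax0, ax1, ax2, ax3, ax4):
    """Index into placeholder read by the 5-D T_reshape, shape (N, 1, C, H, W).
    The extent-1 axis makes ax0*1 + ax1 appear as (ax0 + ax1), as in the DAG."""
    flat = (((ax0 + ax1) * C + ax2) * H + ax3) * W + ax4
    return (
        (flat // W // H // C) % N,
        (flat // W // H) % C,
        (flat // W) % H,
        flat % W,
    )


def second_reshape_src(ax0, ax1, ax2, ax3):
    """Index into the 5-D T_reshape read by the final 4-D T_reshape."""
    flat = ((ax0 * C + ax1) * H + ax2) * W + ax3
    return (
        (flat // W // H // C) % N,
        (flat // W // H // C) % 1,  # extent-1 axis: always 0
        (flat // W // H) % C,
        (flat // W) % H,
        flat % W,
    )


# Composing both index maps should give back the original 4-D index everywhere.
mismatches = [
    idx for idx in product(range(N), range(C), range(H), range(W))
    if first_reshape_src(*second_reshape_src(*idx)) != idx
]
print(len(mismatches))  # 0
```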