[RFC] A general task extraction mechanism for auto_scheduler

Here are some more details about the interface change in this RFC. The newly added use_topi_schedule flag is propagated from the compile engine to relay._build, so it is not actually exposed to users. The use cases are as follows:

  1. Use TOPI schedule with fallback values (same as current).

    with PassContext(opt_level=opt_level):
        lib = relay.build(mod, target=target, params=params)
    
  2. Use TOPI schedule with AutoTVM tuning logs (same as current).

    with autotvm.apply_history_best(log_file):
        with PassContext(opt_level=opt_level):
            lib = relay.build(mod, target=target, params=params)
    
  3. Extract auto_scheduler tasks. Internally, this calls GraphRuntimeCodegen(use_topi_schedule=False) to invoke the compile engine, which lowers the Relay functions to TE compute DAGs.

    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
    

    In extract_tasks:

    with transform.PassContext(opt_level=3):
        opt_mod, _ = relay.optimize(mod, target, params)
        grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target, use_topi_schedule=False)
        grc.codegen(opt_mod["main"])
    
  4. Use auto_scheduler tuning logs. relay.build invokes relay._build(use_topi_schedule=False) when it finds that auto_scheduler.DispatchContext is not None, which indicates that the user wants to apply the auto_scheduler log.

    with auto_scheduler.ApplyHistoryBest(log_file):
        with PassContext(opt_level=opt_level):
            lib = relay.build(mod, target=target, params=params)
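To make the dispatch rule of case 4 concrete, here is a minimal, self-contained toy model of it. It does not import TVM: DispatchContext, ApplyHistoryBest, build, _build and extract_tasks below are simplified, hypothetical stand-ins that only mirror the behavior described above (the flag is inferred from whether an auto_scheduler dispatch context is active). It is a sketch, not the actual TVM implementation.

```python
# Toy model (NOT TVM's actual code): relay.build infers use_topi_schedule
# from whether an auto_scheduler dispatch context is active.
# Every class and function here is a hypothetical stand-in.


class DispatchContext:
    """Hypothetical stand-in for auto_scheduler.DispatchContext."""

    current = None  # innermost active context, or None if none is active

    def __enter__(self):
        self._prev = DispatchContext.current
        DispatchContext.current = self
        return self

    def __exit__(self, *exc):
        # Restore the previously active context (supports nesting).
        DispatchContext.current = self._prev
        return False


class ApplyHistoryBest(DispatchContext):
    """Hypothetical stand-in for auto_scheduler.ApplyHistoryBest(log_file)."""

    def __init__(self, log_file):
        self.log_file = log_file


def _build(mod, use_topi_schedule=True):
    # Stand-in for relay._build: report which scheduling path was taken.
    return "topi" if use_topi_schedule else "auto_scheduler"


def build(mod):
    # Case 4: an active auto_scheduler context means the user wants its log,
    # so TOPI schedules are bypassed. Otherwise (cases 1 and 2) use TOPI.
    use_topi_schedule = DispatchContext.current is None
    return _build(mod, use_topi_schedule=use_topi_schedule)


def extract_tasks(mod):
    # Case 3: task extraction always skips TOPI schedules so the Relay
    # functions are lowered to plain TE compute DAGs.
    return _build(mod, use_topi_schedule=False)


if __name__ == "__main__":
    print(build("mod"))  # topi
    with ApplyHistoryBest("log.json"):
        print(build("mod"))  # auto_scheduler
    print(extract_tasks("mod"))  # auto_scheduler
```

In this model the user never passes the flag; it is derived entirely from the surrounding context.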
    

As a result, the changes are hidden from the end user. On the other hand, putting this flag in the PassContext would change cases 3 and 4 as follows:

  1. (Case 3) In extract_tasks, we can still set the flag on behalf of users.

    with transform.PassContext(opt_level=3, use_topi_schedule=False):
        opt_mod, _ = relay.optimize(mod, target, params)
        grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
        grc.codegen(opt_mod["main"])
    
  2. (Case 4) Users would have to add the flag manually.

    with auto_scheduler.ApplyHistoryBest(log_file):
        with PassContext(opt_level=opt_level, use_topi_schedule=False):
            lib = relay.build(mod, target=target, params=params)
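For comparison, here is a similarly hedged toy model of the PassContext-based alternative. PassContext and build below are hypothetical stand-ins that do not import TVM (the real tvm.transform.PassContext has no use_topi_schedule argument; that argument is the proposal under discussion). The sketch only illustrates that the flag is read from the context, so nothing infers it for the user.

```python
# Toy model (NOT TVM's actual code) of the alternative design: the flag
# lives in a PassContext-like object instead of being inferred in build().
# Every class and function here is a hypothetical stand-in.


class PassContext:
    """Hypothetical stand-in for a PassContext carrying use_topi_schedule."""

    current = None  # innermost active context, or None if none is active

    def __init__(self, opt_level=2, use_topi_schedule=True):
        self.opt_level = opt_level
        self.use_topi_schedule = use_topi_schedule

    def __enter__(self):
        self._prev = PassContext.current
        PassContext.current = self
        return self

    def __exit__(self, *exc):
        # Restore the previously active context (supports nesting).
        PassContext.current = self._prev
        return False


def build(mod):
    # The flag comes only from the enclosing context; default is TOPI.
    ctx = PassContext.current or PassContext()
    return "topi" if ctx.use_topi_schedule else "auto_scheduler"


if __name__ == "__main__":
    with PassContext(opt_level=3):
        print(build("mod"))  # topi: an auto_scheduler log would be ignored
    with PassContext(opt_level=3, use_topi_schedule=False):
        print(build("mod"))  # auto_scheduler
```

In this model, forgetting use_topi_schedule=False silently falls back to TOPI schedules, which is why case 4 would need the extra argument.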
    

IMHO, this changes the user-facing interface more than the current solution does.
