Manual scheduling of call_tir functions in Relax

I’ve been playing around with Relax, and I know the meta-scheduler can be used with it in different ways. Just out of curiosity, I was wondering if there are any easy ways to write manual schedules for the prim_funcs in a Relax IRModule.

Is it possible to somehow iterate through the prim_funcs and manually schedule them? I think I could schedule each prim_func in the Relax IRModule and then use a pass to replace the original prim_funcs with the scheduled ones, but is there a simpler way?

Yep, it is possible to mix different ways of scheduling (manual, handcrafted search space, auto-generated space) with Relax. To do this, you can extract the tasks, customize them, and then plug them back in.

Example:

from tvm import meta_schedule as ms

extracted_tasks = ms.relax_integration.extract_tasks(mod, target, params=None)
# inject_schedule (defined below) customizes the way we schedule each task
tasks, task_weights = inject_schedule(extracted_tasks, WORK_DIR)
database = ms.tune_tasks(
     tasks=tasks,
     task_weights=task_weights,
     work_dir=WORK_DIR,
     max_trials_global=500,
     num_trials_per_iter=32,
)
relax_exec = ms.relax_integration.compile_relax(
    database,
    mod=mod,
    target=target,
    params=None,
)

A typical way of injecting schedules looks like:

def inject_schedule(extracted_tasks, work_dir):
    tasks = []
    task_weights = []
    for task, logger, rand_state in zip(
        extracted_tasks,
        ms.logging.get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]),
        ms.utils.fork_seed(None, n=len(extracted_tasks)),
    ):
        if task.task_name == "XXX":
            # handcrafted search space,
            # or handcrafted schedule without search
            space = ms.space_generator.ScheduleFn(
                sch_fn=my_schedule_function,
                sch_rules=[],
                postprocs=[],
            )
        else:
            # auto-generated search space
            space = "post-order-apply"
        tasks.append(
            ms.TuneContext(
                mod=task.dispatched[0],
                target=task.target,
                space_generator=space,
                task_name=task.task_name,
                logger=logger,
                rand_state=rand_state,
            ).clone()
        )
        task_weights.append(task.weight)
    return tasks, task_weights
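
If more than one task needs a handcrafted schedule, the `if task.task_name == "XXX"` branch can be turned into a lookup table. The sketch below shows only that dispatch step in plain Python, so it runs without TVM. `pick_space`, `MANUAL_SCHEDULES`, `schedule_matmul`, and the task names are hypothetical; in the real function, a returned callable would still need to be wrapped in ms.space_generator.ScheduleFn as in the snippet above.

```python
from typing import Callable, Dict, Union

AUTO_SPACE = "post-order-apply"  # the auto-generated space named in the thread


def schedule_matmul(sch):
    """Placeholder for a handcrafted schedule function (hypothetical)."""
    return sch


# Hypothetical mapping from task name to handcrafted schedule function.
MANUAL_SCHEDULES: Dict[str, Callable] = {"matmul": schedule_matmul}


def pick_space(task_name: str, manual: Dict[str, Callable]) -> Union[Callable, str]:
    """Return the handcrafted schedule for a task, or fall back to the auto space."""
    return manual.get(task_name, AUTO_SPACE)


# "matmul" gets the manual schedule, "add" falls back to the auto space.
choices = {name: pick_space(name, MANUAL_SCHEDULES) for name in ["matmul", "add"]}
```

Keeping the mapping separate from the loop means adding another manual schedule is a one-line change to the dictionary.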

Thanks for the reply @junrushao.

This should be good enough for a starting point, but just out of curiosity, is it at all possible to avoid meta-schedule completely if we’re going to do manual scheduling for all tasks?

If we’re scheduling the whole graph manually, we don’t really need to go through the builder/runner workflow or run any tuning, right? But without tuning, we don’t get a database to pass into compile_relax.

Even if we would love to use tuning for most models, there might always be some cases where users might prefer manual scheduling, hence this question about avoiding meta-schedule in those cases.

Yep, if you only want to schedule manually without any search, you can simply extract the PrimFunc you’re interested in using mod["func"], create a schedule using sch = tir.Schedule(func), schedule it, and plug it back.

Thanks again @junrushao. When you say “plug it back”, I think it can be done by writing a PyExprMutator pass and then using self.builder_ to replace the function after scheduling. Is there a better way?

Actually, I found that I can create a BlockBuilder directly and use it to replace a function with a modified one, so that should be good enough. Thanks.

There is actually a simpler trick if we really want “manual” interactive debugging:

import tvm
from tvm import relax
from tvm.contrib import utils, tvmjs
from tvm.script import relax as R


def get_model():
    pipeline = relax.get_pipeline()

    @tvm.script.ir_module
    class Mod:
        @R.function
        def main(x: R.Tensor([1024], "float32"), 
                 y: R.Tensor([1024], "float32")):
            lv0 = R.add(x, y)
            return lv0
    # legalizes to TIR function add
    mod = pipeline(Mod)
    # look at code before transform
    mod.show()
    sch = tvm.tir.Schedule(mod)
    # manually transform loop
    sch.work_on("add")
    (i,) = sch.get_loops(block=sch.get_block("T_add"))
    i0, i1 = sch.split(i, [None, 128])
    sch.bind(i0, "blockIdx.x")
    sch.bind(i1, "threadIdx.x")
    return sch.mod

You only need to call sch.work_on(func_name) to select the function to work on, and then return sch.mod, which is an IRModule containing the updated function along with the other functions.
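
To see what the split and bind in the snippet above do to the iteration space, here is a plain-Python sketch that needs no TVM install. It assumes the legalized add is a single loop of extent 1024, as in the example. The name `simulate_split_bind` is made up for illustration, and the nested loops only mimic how blockIdx.x and threadIdx.x would together cover the original index; they are not what TVM generates.

```python
def simulate_split_bind(x, y, inner=128):
    """Mimic sch.split(i, [None, inner]) followed by binding to block/thread."""
    n = len(x)
    # Assumption for this sketch: the extent divides evenly by the inner factor.
    assert n % inner == 0, "this sketch assumes the extent divides evenly"
    outer = n // inner  # the `None` factor is inferred as n / inner
    out = [0.0] * n
    for block_idx in range(outer):        # plays the role of blockIdx.x
        for thread_idx in range(inner):   # plays the role of threadIdx.x
            i = block_idx * inner + thread_idx  # recovers the original loop index
            out[i] = x[i] + y[i]
    return outer, out


x = [float(v) for v in range(1024)]
y = [1.0] * 1024
num_blocks, result = simulate_split_bind(x, y)
```

With an extent of 1024 and an inner factor of 128, the outer loop has extent 8, so the bound kernel would launch 8 blocks of 128 threads, each thread handling one element.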

Wow. Yeah that does seem simpler and probably cleaner. Thanks a lot @tqchen