Manual scheduling of call_tir functions in Relax

There is actually a simpler trick if we really want to schedule a function “manually”, e.g. for interactive debugging:

import tvm
from tvm import relax
from tvm.contrib import utils, tvmjs
from tvm.script import relax as R


def get_model():
    """Build a small Relax module and manually schedule its legalized ``add`` kernel.

    The Relax ``main`` function adds two 1024-element float32 tensors. The
    default Relax pipeline is applied first; it is expected to legalize
    ``R.add`` into a TIR PrimFunc named ``add`` containing a block ``T_add``.
    That PrimFunc's single loop is then split into an outer loop and an inner
    loop of 128 iterations. The outer loop is bound to ``blockIdx.x`` and the
    inner loop to ``threadIdx.x``.

    Returns:
        ``sch.mod``: the IRModule held by the schedule after the
        transformation. It contains the rescheduled ``add`` PrimFunc together
        with the module's other functions.
    """
    pipeline = relax.get_pipeline()

    @tvm.script.ir_module
    class Mod:
        @R.function
        def main(x: R.Tensor([1024], "float32"),
                 y: R.Tensor([1024], "float32")):
            lv0 = R.add(x, y)
            return lv0

    # Legalizes R.add to a TIR PrimFunc named "add".
    mod = pipeline(Mod)
    # Print the module to inspect the TIR before any manual transformation.
    mod.show()

    sch = tvm.tir.Schedule(mod)
    # The module holds several functions, so explicitly select the PrimFunc
    # that get_block/get_loops below should operate on.
    sch.work_on("add")
    (i,) = sch.get_loops(block=sch.get_block("T_add"))
    # None lets the schedule infer the outer extent from the loop length.
    i0, i1 = sch.split(i, [None, 128])
    sch.bind(i0, "blockIdx.x")
    sch.bind(i1, "threadIdx.x")
    return sch.mod

You only need to call `sch.work_on(func_name)` to select the function to work on, and then return `sch.mod`. This gives you an IRModule that contains the updated function along with all the other functions.

8 Likes