There is actually a simpler trick if you really want to do “manual” interactive debugging:
import tvm
from tvm import relax
from tvm.contrib import utils, tvmjs
from tvm.script import relax as R
def get_model():
    """Build a tiny Relax module and hand-schedule its legalized TIR ``add``.

    The Relax ``main`` function adds two 1024-element float32 tensors. Running
    the pipeline returned by ``relax.get_pipeline()`` legalizes ``R.add`` into
    a TIR function named ``add``. That function's loop is then manually split
    and bound to ``blockIdx.x`` / ``threadIdx.x`` through ``tvm.tir.Schedule``.

    Returns:
        ``sch.mod``: an IRModule containing the transformed ``add`` function
        together with the other functions of the legalized module.
    """
    pipeline = relax.get_pipeline()

    @tvm.script.ir_module
    class Mod:
        @R.function
        def main(x: R.Tensor([1024], "float32"),
                 y: R.Tensor([1024], "float32")):
            lv0 = R.add(x, y)
            return lv0

    # Legalizes R.add into a TIR function named "add".
    mod = pipeline(Mod)
    # Look at the code before the manual transformation.
    mod.show()

    sch = tvm.tir.Schedule(mod)
    # Select which function in the module the schedule primitives apply to.
    sch.work_on("add")
    # Manually transform the single loop of the "T_add" block.
    (i,) = sch.get_loops(block=sch.get_block("T_add"))
    # Inner loop gets extent 128; the outer extent (None) is inferred.
    i0, i1 = sch.split(i, [None, 128])
    sch.bind(i0, "blockIdx.x")
    sch.bind(i1, "threadIdx.x")
    return sch.mod
You only need to call `sch.work_on(func_name)` to select the function you want to work on, and then return `sch.mod`, which is an IRModule that contains the updated function along with all the other functions.