Registering STIR in Relay strategy

Hello,

I am currently attempting to use TE compute as a TensorIR frontend: convert it into STIR using te.create_prim_func() and then schedule it using a tir.Schedule. However, even after some digging, I still cannot figure out the appropriate way of hooking such an STIR schedule up to the Relay strategy when using this approach:

strategy.add_implementation(
    wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC),
    wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC),
    name="conv2d_NHWC.arm_cpu",
)
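
For context, this is the kind of pipeline I mean, as a minimal sketch with a toy compute standing in for the real conv2d:

import tvm
from tvm import te, tir

# Toy compute standing in for the real conv2d, just to show the flow.
A = te.placeholder((1024,), name="A")
B = te.compute((1024,), lambda i: A[i] * 2.0, name="B")

func = te.create_prim_func([A, B])  # TE -> STIR PrimFunc
sch = tir.Schedule(func)            # schedule it with the STIR API
(i,) = sch.get_loops(sch.get_block("B"))
_, inner = sch.split(i, factors=[None, 32])
sch.vectorize(inner)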

I have tried out 3 methods so far:

  1. Without MetaSchedule, having compute_conv2d_NHWC() return a te.compute output Tensor as usual and schedule_conv2d_NHWC() return a tir.Schedule, I get an error:

      2: tvm::relay::tec::ScheduleBuilder::Create(tvm::relay::Function const&, tvm::GlobalVarSupply, tvm::NameSupply)
      1: tvm::relay::OpImplementation::Schedule(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Target const&)
      0: tvm::te::Schedule tvm::runtime::TVMPODValue_::AsObjectRef<tvm::te::Schedule>() const
      File "/home/parallels/tvm/include/tvm/runtime/packed_func.h", line 2029
    InternalError: Check failed: (!checked_type.defined()) is false: Expected Schedule, but got tir.Schedule
    
  2. While using MetaSchedule:

    import tvm
    from tvm import meta_schedule as ms, relay

    with ms.database.ScheduleFnDatabase(schedule_fn), tvm.transform.PassContext(
        opt_level=3,
        config={
            "relay.backend.use_meta_schedule": True,
            "relay.backend.tir_converter": "allow_extern",
        },
    ):
        lib = relay.build(relay_mod, target=target, params=params)
    

    I noticed that schedule_conv2d_NHWC(outs) no longer gets called in this configuration, so I attempted scheduling elsewhere in the code:

    • Directly in the compute_conv2d_NHWC() function:

      def compute_conv2d_NHWC(...):
          ...
          out = te.compute(...)
          func = te.create_prim_func(inputs + [out])
          sch = tir.Schedule(func)
          ... (scheduling) ...
          return te.extern_primfunc(inputs, sch.mod["main"], attrs={})
      

      However, the PrimFunc does not include the fused operators (e.g. relu, add) at this point, so I cannot inline them from here.

    • In the schedule_fn function that is passed as an argument to ms.database.ScheduleFnDatabase(schedule_fn):

      def schedule_fn(sch):
          if "fused_conv2d_add_relu" in sch.mod.attrs["task_name"]:
              schedule_conv2d_NHWC(sch)
              return True
          return False
      

      While the PrimFunc does include the fused operators at this point, this would mean specifying the scheduling function for each fused pattern again by hand (roughly as in the sketch below), when this information should already be captured by the Relay strategy.
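
      Concretely, I would end up maintaining something like this by hand (a rough sketch; the second entry and its schedule function are made up):

      # Hypothetical hand-maintained dispatch table, duplicating the
      # mapping that the Relay strategy already captures.
      TASK_SCHEDULES = {
          "fused_conv2d_add_relu": schedule_conv2d_NHWC,
          "fused_depthwise_conv2d_add_relu": schedule_depthwise_conv2d_NHWC,  # made-up name
      }

      def schedule_fn(sch):
          task_name = sch.mod.attrs["task_name"]
          for pattern, fn in TASK_SCHEDULES.items():
              if pattern in task_name:
                  fn(sch)
                  return True
          return False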

To sum everything up in a single question: is there a way to perform STIR scheduling inside the schedule function that is registered alongside a compute definition in the Relay strategy (or something similar), such that both conv2d and its fused operators can be scheduled together? If not, what is the currently accepted way of registering STIR in the Relay strategy? Thank you!

You could write your own RelayToTIR plugin to convert calls to Relay primitives into calls to (S)TIR functions. Refer to src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc as an example; the developer is completely free to customize the workflow in such a plugin.

But generally, migrating to Relax would be friendlier if one wants to use the STIR path.
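
For example (a rough sketch, assuming a Relax/Unity build): after relax.transform.LegalizeOps() every operator becomes a TIR PrimFunc inside the IRModule, so fused patterns can be scheduled directly with tir.Schedule, without going through a Relay strategy at all:

import tvm
from tvm import tir

# Rough sketch: in Relax, legalized operators are plain TIR PrimFuncs
# living in the IRModule, so STIR scheduling applies to them directly.
def schedule_module(mod: tvm.IRModule) -> tvm.IRModule:
    for gv, func in mod.functions.items():
        if isinstance(func, tir.PrimFunc):
            sch = tir.Schedule(mod)
            sch.work_on(gv.name_hint)
            # ... apply STIR scheduling primitives here ...
            mod = sch.mod
    return mod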
