Hello,
I am currently trying to use a TE compute definition as a TensorIR frontend: I convert it into STIR with `te.create_prim_func()` and then schedule it with a `tir.Schedule`. However, even after some digging, I still cannot figure out the appropriate way to hook such an STIR schedule up to the Relay strategy when using this approach:
```python
strategy.add_implementation(
    wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC),
    wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC),
    name="conv2d_NHWC.arm_cpu",
)
```
I have tried three approaches so far:
- Without MetaSchedule: when `compute_conv2d_NHWC()` returns a `te.compute` output tensor as usual and `schedule_conv2d_NHWC()` returns a `tir.Schedule`, I get this error:

  ```
  2: tvm::relay::tec::ScheduleBuilder::Create(tvm::relay::Function const&, tvm::GlobalVarSupply, tvm::NameSupply)
  1: tvm::relay::OpImplementation::Schedule(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Target const&)
  0: tvm::te::Schedule tvm::runtime::TVMPODValue_::AsObjectRef<tvm::te::Schedule>() const
    File "/home/parallels/tvm/include/tvm/runtime/packed_func.h", line 2029
  InternalError: Check failed: (!checked_type.defined()) is false: Expected Schedule, but got tir.Schedule
  ```

- With MetaSchedule, building as follows:

  ```python
  import tvm
  from tvm import meta_schedule as ms
  from tvm import relay

  with ms.database.ScheduleFnDatabase(schedule_fn), tvm.transform.PassContext(
      opt_level=3,
      config={
          "relay.backend.use_meta_schedule": True,
          "relay.backend.tir_converter": "allow_extern",
      },
  ):
      lib = relay.build(relay_mod, target=target, params=params)
  ```

  I noticed that `schedule_conv2d_NHWC(outs)` is no longer called in this configuration, so I tried to do the scheduling elsewhere in the code:

  - Directly in the `compute_conv2d_NHWC()` function:

    ```python
    def compute_conv2d_NHWC(...):
        ...
        out = te.compute(...)
        func = te.create_prim_func(inputs + [out])
        sch = tir.Schedule(func)
        ...  # (scheduling)
        return te.extern_primfunc(inputs, sch.mod["main"], attrs={})
    ```

    However, the PrimFunc does not include the fused operators (e.g. `add`, `relu`) at this point, so I cannot inline them from here.

  - In the `schedule_fn` function passed as the argument to `ms.database.ScheduleFnDatabase(schedule_fn)`:

    ```python
    def schedule_fn(sch):
        if "fused_conv2d_add_relu" in sch.mod.attrs["task_name"]:
            schedule_conv2d_NHWC(sch)
            return True
        return False
    ```

    The PrimFunc does include the fused operators at this point, but this approach means I would have to specify the scheduling function for each operator a second time, even though that information should already be captured by the Relay strategy.

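A minimal sketch of how the per-task dispatch in the `schedule_fn` approach could at least be centralized, assuming the callback receives a `tir.Schedule` whose `mod.attrs["task_name"]` holds the task name (as in the snippet above). `TIR_SCHEDULE_REGISTRY` and `register_tir_schedule` are hypothetical helpers, not TVM APIs, and the code only touches `sch.mod.attrs`, so the scheduling functions themselves are left to the user.

```python
from typing import Any, Callable, Dict

# Hypothetical registry (not a TVM API): maps a substring of the
# MetaSchedule task name, e.g. "conv2d", to a function that applies
# tir.Schedule primitives to the schedule in place.
TIR_SCHEDULE_REGISTRY: Dict[str, Callable[[Any], None]] = {}


def register_tir_schedule(task_substring: str) -> Callable:
    """Decorator that registers a TIR scheduling function for matching tasks."""

    def decorator(fn: Callable[[Any], None]) -> Callable[[Any], None]:
        TIR_SCHEDULE_REGISTRY[task_substring] = fn
        return fn

    return decorator


def schedule_fn(sch: Any) -> bool:
    """Callback of the form passed to ms.database.ScheduleFnDatabase.

    Returns True when a registered function scheduled this task and
    False when no registered key matches the task name.
    """
    task_name = str(sch.mod.attrs["task_name"])
    # Check longer keys first so a specific key such as "conv2d_add"
    # takes precedence over a generic one such as "conv2d".
    for key in sorted(TIR_SCHEDULE_REGISTRY, key=len, reverse=True):
        if key in task_name:
            TIR_SCHEDULE_REGISTRY[key](sch)
            return True
    return False
```

With this, a function decorated with `@register_tir_schedule("conv2d")` would be applied to any task whose name contains `conv2d`. It still maintains a second, TIR-only mapping next to the Relay strategy instead of reusing it, so it only tidies up the third approach and does not remove the duplication the question is about.
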
To sum everything up in a single question: is there a way to perform STIR scheduling inside the schedule function that is registered alongside a compute definition in the Relay strategy (or something similar), so that conv2d and its fused operators can be scheduled together? If not, what is the currently accepted way of registering STIR schedules in the Relay strategy?

Thank you!