Hi, I am trying to add an operator to Relay as described in the tutorial. I want to write the schedule using TIR, which seems to be an issue. Here is the registration of the op:
@tvm.ir.register_op_attr("contrib.learning.gemm", "FTVMStrategy")
def gemm_strategy_learning(attrs, inputs, out_type, target):
    """Strategy implementation for the custom GEMM operator.

    Args:
        attrs (tvm.runtime.object.Object): attributes for the strategy
        inputs (tvm.ir.container.Array): inputs
        out_type (tvm.ir.tensor_type.TensorType): output type
        target (tvm.target.target.Target): target for the strategy

    Returns:
        OpStrategy: strategy implementation
    """
    if len(inputs) == 3:
        strategy = OpStrategy()
        strategy.add_implementation(
            wrap_gemm_topi_compute(gemm_cisc_tir),
            _strategy.wrap_topi_schedule(schedule_gemm_cisc_tir),
            name="contrib.learning.gemm",
        )
        return strategy
    return None
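For context, `wrap_gemm_topi_compute` is not shown above; a minimal sketch of what such a wrapper typically looks like, modeled on TVM's `wrap_compute_*` helpers (the name and exact signature here are assumptions, not the actual implementation):

def wrap_gemm_topi_compute(topi_compute):
    """Wrap a TOPI compute so it matches the FTVMCompute calling convention."""
    def _compute(attrs, inputs, out_type):
        # Forward the operator inputs (data, weight, bias) to the compute
        return [topi_compute(*inputs)]
    return _compute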
And here are the compute and schedule definitions:
def gemm_cisc_tir(
    data: tvm.te.tensor.Tensor,
    weight: tvm.te.tensor.Tensor,
    bias: tvm.te.tensor.Tensor,
) -> tvm.te.tensor.Tensor:
    """Computation definition for my custom GEMM.

    Args:
        data (tvm.te.tensor.Tensor): input feature map
        weight (tvm.te.tensor.Tensor): layer weights
        bias (tvm.te.tensor.Tensor): layer biases

    Returns:
        tvm.te.tensor.Tensor: dense operator result
    """
    # Derive shapes
    ishape = topi.utils.get_const_tuple(data.shape)
    wshape = topi.utils.get_const_tuple(weight.shape)
    oshape = (ishape[0], wshape[1])
    rk = te.reduce_axis((0, wshape[0]), name="rk")
    # Note: `+ bias[y_o]` inside te.sum accumulates the bias once per
    # reduction step; the usual dense-with-bias adds the bias after the
    # reduction, in a separate te.compute stage.
    res = te.compute(
        oshape,
        lambda x_o, y_o: te.sum(
            data[x_o, rk] * weight[rk, y_o] + bias[y_o],
            axis=[rk],
        ),
        name="res",
        tag="dense",
    )
    return res
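One subtlety in the compute above: `bias[y_o]` sits inside `te.sum`, so it is accumulated once per reduction step rather than added once. A small numpy sketch (shapes made up for illustration) shows the difference:

import numpy as np

K = 4
data = np.ones((2, K))
weight = np.ones((K, 3))
bias = np.full(3, 0.5)

# Equivalent to te.sum(d * w + bias, axis=rk): bias is summed K times
inside_sum = data @ weight + K * bias
# Bias added once, after the reduction (the usual dense-with-bias)
after_sum = data @ weight + bias
print(inside_sum[0, 0], after_sum[0, 0])  # 6.0 4.5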
def schedule_gemm_cisc_tir(
    outs: tvm.ir.container.Array,
) -> tvm.tir.Schedule:
    """Schedule definition for my custom GEMM.

    Args:
        outs (tvm.ir.container.Array): output tensors

    Returns:
        tvm.tir.Schedule: transformed schedule
    """
    # Normalize to a list *before* indexing into it
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    output = outs[0]
    data, weight, bias = output.op.input_tensors
    func = te.create_prim_func([data, weight, bias, output])
    sch = tvm.tir.Schedule(func)
    return sch
The actual scheduling is omitted for brevity. My goal is to lower from Relay to C code; the pattern matching on the Relay side is already implemented. What is the correct way to proceed here? Simply returning a tir.Schedule doesn't seem to be the right way.