We don't have to define intermediate customized ops. Instead, we directly leverage the cross-level property of Relax to lower to call_tir (when we want to supply a customized TIR function) or to call_dps_packed (when we call into a library function that is registered as a packed func).
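Both call_tir and call_dps_packed use destination-passing style (DPS). The caller allocates the output from the declared output struct info and passes it to the callee as the last argument. That is why custom_tir_concat in the example takes z as a third buffer instead of returning a value, and a library function called through call_dps_packed follows the same (inputs..., output) argument order. The sketch below models only that convention in plain Python with nested lists. alloc, concat_axis1_dps and call_dps are hypothetical names for illustration, not TVM APIs.

```python
def alloc(rows, cols):
    """Allocate a zero-filled rows x cols matrix (stand-in for a runtime tensor)."""
    return [[0.0] * cols for _ in range(rows)]


def concat_axis1_dps(x, y, out):
    """Hypothetical DPS kernel: write concat((x, y), axis=1) into `out`.

    Same argument shape as custom_tir_concat(x, y, z): inputs first,
    preallocated output last, nothing returned.
    """
    for i, (row_x, row_y) in enumerate(zip(x, y)):
        # The caller must have allocated exactly len(row_x) + len(row_y) columns.
        assert len(out[i]) == len(row_x) + len(row_y)
        out[i][: len(row_x)] = row_x
        out[i][len(row_x):] = row_y


def call_dps(kernel, args, out_shape):
    """Toy analogue of the lowering: allocate the declared output, then call kernel(*args, out)."""
    out = alloc(*out_shape)
    kernel(*args, out)
    return out


x = [[1.0, 2.0], [3.0, 4.0]]
y = [[5.0], [6.0]]
print(call_dps(concat_axis1_dps, (x, y), (2, 3)))  # [[1.0, 2.0, 5.0], [3.0, 4.0, 6.0]]
```

Keeping allocation on the caller side is what lets the same Relax call site target either a TIR function or an external library without changing how the output is managed.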
See an example below. The Hexagon dispatch pass takes the IRModule Before and can choose to rewrite the concat either into AfterCustomTIR, which contains a customized TIR implementation, or into AfterExternalLibrary, which calls into a library function with a packed func interface (one that can be generated via TVM_DLL_EXPORT_TYPED_FUNC).
import tvm
from tvm.script import relax as R, tir as T

@tvm.script.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((128, 256), "float32"),
        y: R.Tensor((128, 256), "float32")
    ):
        lv0: R.Tensor((128, 512), "float32") = R.concat((x, y), axis=1)
        ....

@tvm.script.ir_module
class AfterCustomTIR:
    @T.prim_func
    def custom_tir_concat(
        x: T.Buffer((128, 256), "float32"),
        y: T.Buffer((128, 256), "float32"),
        z: T.Buffer((128, 512), "float32")
    ):
        # customized TIR code goes here

    @R.function
    def main(
        x: R.Tensor((128, 256), "float32"),
        y: R.Tensor((128, 256), "float32")
    ):
        cls = AfterCustomTIR
        lv0 = R.call_tir(
            cls.custom_tir_concat, (x, y),
            R.Tensor((128, 512), "float32")
        )
        ....

@tvm.script.ir_module
class AfterExternalLibrary:
    @R.function
    def main(
        x: R.Tensor((128, 256), "float32"),
        y: R.Tensor((128, 256), "float32")
    ):
        cls = AfterExternalLibrary
        lv0 = R.call_dps_packed(
            "hexagon_concat_packed_func",
            (x, y), R.Tensor((128, 512), "float32")
        )
        ....
The hexagon_concat_packed_func can also be linked in from an external module built with TVM_DLL_EXPORT_TYPED_FUNC; see also apache/tvm Pull Request #16006, "[Unity][nn.Module] Support `nn.SourceModule`" by junrushao.
As we can see, the main goal is to remove as much boilerplate as possible and to make things composable.
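The choice the dispatch pass makes between AfterCustomTIR and AfterExternalLibrary can be modeled in a few lines. This is a toy in plain Python, not TVM's pass infrastructure or Relax API: Concat, CallTIR, CallDPSPacked and dispatch_concat are hypothetical stand-ins, and the "only axis=1 is supported" rule is an assumed example of a pass declining cases it cannot handle. What it shows is the point made above: the high-level op is rewritten straight into one of the two lowered call forms, with no intermediate custom op in between.

```python
from dataclasses import dataclass
from typing import Tuple, Union

Shape = Tuple[int, ...]


@dataclass(frozen=True)
class Concat:
    """Stands in for R.concat((x, y), axis=1) in Before."""
    args: Tuple[str, ...]
    axis: int
    out_shape: Shape


@dataclass(frozen=True)
class CallTIR:
    """Stands in for R.call_tir(cls.custom_tir_concat, (x, y), out_sinfo)."""
    prim_func: str
    args: Tuple[str, ...]
    out_shape: Shape


@dataclass(frozen=True)
class CallDPSPacked:
    """Stands in for R.call_dps_packed("hexagon_concat_packed_func", (x, y), out_sinfo)."""
    packed_func: str
    args: Tuple[str, ...]
    out_shape: Shape


def dispatch_concat(op: Concat, use_external_library: bool) -> Union[Concat, CallTIR, CallDPSPacked]:
    """Rewrite a concat directly into a lowered call (hypothetical dispatch rule).

    Cases outside the assumed supported set (axis != 1) are returned unchanged
    so that a default lowering can handle them later.
    """
    if op.axis != 1:
        return op
    if use_external_library:
        return CallDPSPacked("hexagon_concat_packed_func", op.args, op.out_shape)
    return CallTIR("custom_tir_concat", op.args, op.out_shape)


before = Concat(args=("x", "y"), axis=1, out_shape=(128, 512))
print(dispatch_concat(before, use_external_library=False))
# CallTIR(prim_func='custom_tir_concat', args=('x', 'y'), out_shape=(128, 512))
print(dispatch_concat(before, use_external_library=True))
# CallDPSPacked(packed_func='hexagon_concat_packed_func', args=('x', 'y'), out_shape=(128, 512))
```

Because both lowered forms carry the same arguments and output shape, swapping one target for the other only touches the dispatch rule, not the surrounding program.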