Target specific Op legalization for Relax

Hi all,

As far as I can tell, Relax's op legalization registration allows only one compute to be registered per op.

This creates a problem for us: on a specific target (Hexagon in our case), we would like to replace quite a few ops with custom TIR implementations or library calls. Being able to specify target-specific legalization (something like the TE op strategy) would be really useful here.

So my question is: is target-specific legalization possible now (as far as I can tell, it is not)? If not, would it be okay if I added support for it?

Yes, this is something TVM Unity is designed for. The main approach here is to build something composable:

  • Attach the target as an IRModule attribute
  • S0: Create a Hexagon dispatch pass that optionally takes some of the operators and rewrites them into customized ops
  • S1: Run the default pipeline (which also contains op legalization) to cover the remaining ops

So the main design goal here is to make it easy to plug in S0 as a separate pass and compose it with the current pipeline.
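The S0/S1 composition can be modeled without TVM at all. The following is a toy sketch in plain Python, not the TVM API: the op records and the names `HEXAGON_OVERRIDES`, `hexagon_dispatch`, `default_legalize`, `run_pipeline`, and `hexagon_softmax_packed_func` are hypothetical and exist only for illustration. In a real build, S0 would be a pass over the IRModule and S1 would be the default pipeline with its op legalization.

```python
from typing import Callable, Dict, List

# Toy stand-in for an IRModule: an ordered list of ops.
# "lowered" is empty while the op is still a high-level op.
Op = Dict[str, str]
Module = List[Op]
Pass = Callable[[Module], Module]

# Hypothetical per-target table: op name -> customized lowering.
HEXAGON_OVERRIDES = {
    "concat": "call_tir(custom_tir_concat)",
    "softmax": "call_dps_packed(hexagon_softmax_packed_func)",
}


def hexagon_dispatch(mod: Module) -> Module:
    """S0: rewrite only the ops this target chooses to customize."""
    out = []
    for op in mod:
        if not op["lowered"] and op["name"] in HEXAGON_OVERRIDES:
            op = {**op, "lowered": HEXAGON_OVERRIDES[op["name"]]}
        out.append(op)
    return out


def default_legalize(mod: Module) -> Module:
    """S1: lower every op that is still high-level with its default compute."""
    return [
        op if op["lowered"] else {**op, "lowered": f"call_tir(default_{op['name']})"}
        for op in mod
    ]


def run_pipeline(mod: Module, passes: List[Pass]) -> Module:
    """Apply the passes in order; each pass returns a new module."""
    for p in passes:
        mod = p(mod)
    return mod


before = [{"name": n, "lowered": ""} for n in ("concat", "add", "softmax")]
after = run_pipeline(before, [hexagon_dispatch, default_legalize])
for op in after:
    print(op["name"], "->", op["lowered"])
```

Because S0 only touches the ops it knows about and S1 skips anything already lowered, dropping `hexagon_dispatch` from the pass list falls back to the default lowering for every op, with no change to S1.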

Thanks for the quick reply @tqchen.

If I understand correctly, this would mean that for every customized implementation of an op, we would need to define and register a target-specific op.

For example, if we wanted to replace the concatenate op with a custom implementation for Hexagon, then along with the Hexagon dispatch pass we would also need to register it as a custom op, right (something like R.vm.hexagon.concat)?

So do we already have some target-specific custom ops registered in our codebase, and is it okay to register target-specific ops that would be almost (if not exactly) the same as the default ops?

We don't have to define intermediate customized ops. Instead, we directly leverage the cross-level property of Relax and lower to call_tir (when we would like to supply a customized TIR function) or call_dps_packed (when we call into a library function that is registered as a packed func).

See the example below. The Hexagon dispatch pass takes the IRModule Before and can choose to rewrite the concat either into AfterCustomTIR, which contains a customized TIR implementation, or into AfterExternalLibrary, which calls a library function with a packed func interface (which can be generated via TVM_DLL_EXPORT_TYPED_FUNC).


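import tvm
from tvm.script import relax as R, tir as T


@tvm.script.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((128, 256), "float32"),
        y: R.Tensor((128, 256), "float32"),
    ):
        # High-level Relax op that the dispatch pass may rewrite
        lv0: R.Tensor((128, 512), "float32") = R.concat((x, y), axis=1)
        ...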


@tvm.script.ir_module
class AfterCustomTIR:
    @T.prim_func
    def custom_tir_concat(
        x: T.Buffer((128, 256), "float32"),
        y: T.Buffer((128, 256), "float32"),
        z: T.Buffer((128, 512), "float32"),
    ):
        # customized TIR code goes here

    @R.function
    def main(
        x: R.Tensor((128, 256), "float32"),
        y: R.Tensor((128, 256), "float32"),
    ):
        cls = AfterCustomTIR
        # call_tir allocates the (128, 512) output and passes it as the last argument (z)
        lv0 = R.call_tir(
            cls.custom_tir_concat,
            (x, y),
            out_sinfo=R.Tensor((128, 512), "float32"),
        )
        ...

@tvm.script.ir_module
class AfterExternalLibrary:
    @R.function
    def main(
        x: R.Tensor((128, 256), "float32"),
        y: R.Tensor((128, 256), "float32"),
    ):
        # The packed func is looked up by name and receives the output as its last argument
        lv0 = R.call_dps_packed(
            "hexagon_concat_packed_func",
            (x, y),
            out_sinfo=R.Tensor((128, 512), "float32"),
        )
        ...

The hexagon_concat_packed_func can also be linked in through an external module using TVM_DLL_EXPORT_TYPED_FUNC; see also apache/tvm pull request #16006, "[Unity][nn.Module] Support `nn.SourceModule`" by junrushao.
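Both call_tir and call_dps_packed rely on destination-passing style: the callee receives its inputs plus a preallocated output and writes the result into it. A plain-Python sketch of that calling convention follows. It is only a model: `call_dps` and the list-of-lists "tensors" are hypothetical stand-ins, and a real hexagon_concat_packed_func would be a compiled function operating on tensors.

```python
from typing import Callable, List, Sequence, Tuple

Matrix = List[List[float]]


def hexagon_concat_packed_func(x: Matrix, y: Matrix, out: Matrix) -> None:
    """Concatenate x and y along axis 1, writing into the caller-provided out."""
    for i, (row_x, row_y) in enumerate(zip(x, y)):
        out[i][: len(row_x)] = row_x   # left part of the output row
        out[i][len(row_x):] = row_y    # right part of the output row


def call_dps(func: Callable[..., None], args: Sequence[Matrix],
             out_shape: Tuple[int, int]) -> Matrix:
    """Toy model of a DPS call: allocate the output, pass it last, return it."""
    rows, cols = out_shape
    out = [[0.0] * cols for _ in range(rows)]
    func(*args, out)
    return out


x = [[1.0, 2.0], [3.0, 4.0]]
y = [[5.0, 6.0], [7.0, 8.0]]
z = call_dps(hexagon_concat_packed_func, (x, y), (2, 4))
print(z)  # [[1.0, 2.0, 5.0, 6.0], [3.0, 4.0, 7.0, 8.0]]
```

Since the caller owns the output allocation, the library function never needs to know how memory is planned, which is what lets a customized TIR function and an external library function be swapped at the same call site.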

As we can see, the main goal is to remove boilerplate as much as possible and make things composable.

Ah okay, I get the idea now. Since the customization is specific to each target, a target-specific lowering pass can be called first to handle the customized ops, and the remaining ops can then be lowered with a call to LegalizeOps.

Thanks a lot for the explanation, that makes sense.
