How can I use a topi implementation in target codegen?

Hello,

My problem is that Metal doesn’t have a built-in erf function, so we get a crash at runtime if the model contains the error function. I checked, and it should be possible to use the fast_erf implementation without accuracy problems. But I don’t know the right way to replace erf with fast_erf in the generated code. I see the following options:

  1. Implement a Metal-specific relay codegen and use it to replace erf with fast_erf in relay. I don’t like this option, at least because it would require registering a new pattern_table for Metal and calling it from Python code. I would prefer that replacing erf with fast_erf doesn’t require any additional action from the user.
  2. Add a check somewhere in relay.build so that, when we build a model for Metal, a pass that replaces erf with fast_erf is applied. I don’t like this approach because it looks like a workaround rather than the right solution.
  3. Write a fast_erf implementation in the Metal language and have the Metal codegen add this function to each kernel that uses erf. I don’t like this option because we can get the same code from TIR, and I don’t want to implement fast_erf a second time for the Metal codegen.
  4. Call topi::fast_erf from the Metal codegen. For example, here, I tried to do it from intrin_rule_metal.cc. I like this option more than the others, but I couldn’t find out how to create a valid PrimExpr from a TIR op.

Could you please help me? What is the best way to solve this problem? Maybe there is another option that would be better in this case?

Yes, I believe updating intrin_rule_metal.cc is probably the right way to legalize intrinsics for each backend in general. For this particular case, the easiest way is to use the relay FastMath pass: tvm/fast_math.cc at 813136401a11a49d6c15e6013c34dd822a5c4ff6 · apache/tvm · GitHub. This will replace erf with fast_erf at the relay level.
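To make "replace erf with fast_erf at the relay level" concrete, here is a minimal toy sketch of a graph-level operator substitution. It does not use TVM: `Var`, `Call`, `replace_op` and `to_text` are hypothetical names invented for illustration, and the real FastMath pass works on Relay expressions rather than on this tiny tree.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Var:
    """A named graph input (hypothetical stand-in for a graph variable)."""
    name: str


@dataclass(frozen=True)
class Call:
    """An operator applied to arguments (hypothetical stand-in for a call node)."""
    op: str
    args: Tuple["Expr", ...]


Expr = Union[Var, Call]


def replace_op(expr: Expr, old: str, new: str) -> Expr:
    """Return a copy of `expr` in which every call to `old` is renamed to `new`."""
    if isinstance(expr, Call):
        new_args = tuple(replace_op(arg, old, new) for arg in expr.args)
        return Call(new if expr.op == old else expr.op, new_args)
    return expr  # variables are left untouched


def to_text(expr: Expr) -> str:
    """Render the tree as a nested call string."""
    if isinstance(expr, Var):
        return expr.name
    return f"{expr.op}({', '.join(to_text(arg) for arg in expr.args)})"


# A small graph containing one erf call.
graph = Call("multiply", (Var("x"), Call("erf", (Call("add", (Var("x"), Var("b"))),))))
rewritten = replace_op(graph, "erf", "fast_erf")
print(to_text(rewritten))  # prints: multiply(x, fast_erf(add(x, b)))
```

The substitution itself does not depend on the target, which is why applying it only for Metal needs some extra, target-aware decision on top of it.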

Thank you for your reply! Yes, I agree that the easiest way is to use relay FastMath. But I didn’t find the right way to apply FastMath only for a specific target. Options 1 and 2 were about exactly this case, replacing erf with fast_erf at the relay level, but both options look like workarounds. Maybe I missed a mechanism that allows us to add specific passes for a specific target? Could you please suggest the right way to do such a replacement?

I see. I think you can go with option 4 by refactoring the topi fast_erf function: tvm/elemwise.h at 813136401a11a49d6c15e6013c34dd822a5c4ff6 · apache/tvm · GitHub

Everything inside compute works with PrimExpr, so you should be able to use it from the tir.erf lowering in intrin_rule_metal.cc.
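The reason this can work is that an erf approximation is just a chain of elementwise arithmetic on a scalar, which is the kind of expression a lowering rule can build. As a rough illustration, the sketch below approximates erf using only multiply, add, divide, abs and exp, and compares it with `math.erf`. Note the assumption: it uses the Abramowitz and Stegun formula 7.1.26, not the formula or coefficients of topi's fast_erf, and `approx_erf` is a hypothetical name.

```python
import math

# Abramowitz and Stegun 7.1.26 (documented absolute error around 1.5e-7).
# These are NOT the coefficients used by topi's fast_erf.
P = 0.3275911
A = (0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429)


def approx_erf(x: float) -> float:
    """Approximate erf(x) with elementary operations only."""
    sign = -1.0 if x < 0 else 1.0  # erf is odd, so evaluate on |x| and restore the sign
    ax = abs(x)
    t = 1.0 / (1.0 + P * ax)
    # Horner evaluation of a1*t + a2*t^2 + ... + a5*t^5
    poly = 0.0
    for a in reversed(A):
        poly = (poly + a) * t
    return sign * (1.0 - poly * math.exp(-ax * ax))


# Compare against the reference implementation on a grid over [-4, 4].
xs = [i / 100.0 for i in range(-400, 401)]
max_err = max(abs(approx_erf(x) - math.erf(x)) for x in xs)
print(f"max abs error on [-4, 4]: {max_err:.2e}")
```

Every step here has a direct elementwise counterpart, so the same structure could in principle be expressed as a single PrimExpr instead of a call to a missing built-in.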

See, for example, how to implement a custom lowering: tvm/intrin_rule_spirv.cc at main · apache/tvm · GitHub
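The general idea behind a per-backend lowering rule can be sketched as a lookup keyed by target and operator name. This is only a toy model under stated assumptions: `LOWERING_RULES`, `register_lowering`, `lower_call` and the `fast_erf_expansion` placeholder are hypothetical and are not TVM's registration API; the real rules are written in C++ in files such as intrin_rule_metal.cc.

```python
from typing import Callable, Dict, Tuple

# (target, op name) -> function that rewrites the call (hypothetical registry).
LOWERING_RULES: Dict[Tuple[str, str], Callable[[str], str]] = {}


def register_lowering(target: str, op_name: str):
    """Decorator that records a lowering rule for one op on one target."""
    def decorator(fn: Callable[[str], str]) -> Callable[[str], str]:
        LOWERING_RULES[(target, op_name)] = fn
        return fn
    return decorator


@register_lowering("metal", "erf")
def lower_erf_for_metal(arg: str) -> str:
    # Placeholder: a real rule would return the expanded approximation
    # expression instead of a call to a built-in the target lacks.
    return f"fast_erf_expansion({arg})"


def lower_call(target: str, op_name: str, arg: str) -> str:
    """Apply the target's rule if one exists, otherwise keep the plain call."""
    rule = LOWERING_RULES.get((target, op_name))
    if rule is None:
        return f"{op_name}({arg})"
    return rule(arg)


print(lower_call("metal", "erf", "x"))  # prints: fast_erf_expansion(x)
print(lower_call("cuda", "erf", "x"))   # prints: erf(x)
```

Because the rule is selected by target during lowering, the user does not have to enable anything by hand, which is the property asked for in the original question.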

Thank you! I’ll try to do it.