Building Relay from Google JAX models: will TVM support it in the future?
I think that would be an interesting project, and something that is entirely feasible.
But personally, since developing a frontend requires significant effort, we already have good support for PyTorch, and PyTorch is increasingly adding JAX-inspired features, I’d rather improve our support for PyTorch.
One approach that is very straightforward is to replace Jax’s `jit` with a `tvm_jit` interpreter which maps the core primitives to a TVM compilation flow instead of the XLA flow. It is also possible to expose `tvm_call` as a JAX primitive, which allows mixing and matching XLA and TVM in the same program. It is probably worth prototyping to understand the complexity, but it would require mapping the entire primitive set of HLO to be feasible and useful.
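To make the interpreter idea concrete, here is a minimal sketch of the pattern, following the custom-interpreter approach from the JAX docs: trace the function to a jaxpr, then walk its equations and dispatch each primitive through a lowering table. Everything named here (`LOWERING`, `eval_jaxpr_with`, `toy_jit`) is hypothetical, and NumPy stands in for the TVM backend, so this only illustrates the dispatch structure, not an actual `tvm_jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical lowering table: JAX primitive name -> backend implementation.
# NumPy is only a stand-in for whatever a TVM lowering would emit.
LOWERING = {
    "add": np.add,
    "sub": np.subtract,
    "mul": np.multiply,
    "neg": np.negative,
    "sin": np.sin,
    "cos": np.cos,
    "exp": np.exp,
    "tanh": np.tanh,
}


def eval_jaxpr_with(table, closed_jaxpr, *args):
    """Interpret a traced jaxpr, dispatching each primitive through `table`."""
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(v):
        # Literals carry their value inline; variables are looked up in env.
        return v.val if hasattr(v, "val") else env[v]

    for var, const in zip(jaxpr.constvars, closed_jaxpr.consts):
        env[var] = np.asarray(const)
    for var, arg in zip(jaxpr.invars, args):
        env[var] = np.asarray(arg)

    for eqn in jaxpr.eqns:
        name = eqn.primitive.name
        if name not in table:
            # A real frontend would need a lowering for every primitive.
            raise NotImplementedError(f"no lowering for primitive '{name}'")
        out = table[name](*[read(v) for v in eqn.invars])
        outs = out if eqn.primitive.multiple_results else [out]
        for var, val in zip(eqn.outvars, outs):
            env[var] = val

    return [read(v) for v in jaxpr.outvars]


def toy_jit(fn, table=LOWERING):
    """Hypothetical stand-in for `tvm_jit`: trace, then run via `table`."""
    def wrapped(*args):
        closed = jax.make_jaxpr(fn)(*args)
        outs = eval_jaxpr_with(table, closed, *args)
        return outs[0] if len(outs) == 1 else tuple(outs)
    return wrapped


def f(x):
    return jnp.sin(x) * x + x


x = np.linspace(-1.0, 1.0, 5).astype(np.float32)
result = toy_jit(f)(x)
```

A real version would build a TVM program from the equations once and cache the compiled artifact, rather than re-tracing and interpreting on every call as this sketch does.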
Some work on auto-tuning and integration with MetaSchedule or other automation tech would also be required.
This might be of use to you:
https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html
https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
Thank you for the answer and sharing.
Curious about the first solution: is `tvm_jit` already implemented somewhere?
I don’t think so… I was not aware of this feature existing.
There is not a complete version of this, but it can be developed fairly quickly using the machinery we already support. Here is an example of a TVM-based JIT for Jax that I was able to write in about 1.5 hours: https://github.com/jroesch/tax/tree/main.
It is quite possible to borrow much of the Jaxpr normalization code from the XLA JIT and then easily map to Relay and compile using TVM. The above code gives an initial sketch of how this would work. It shows both JIT’ing “inference” code and also a gradient of the same function.
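As a rough illustration of why the same mapping can serve both inference and gradient code: `jax.grad` of a function also traces to a jaxpr made of ordinary primitives, so one translator can handle both as long as every primitive has a lowering. The sketch below only collects primitive names and checks them against a hypothetical `SUPPORTED` set; it does not touch TVM, and `primitives_used` and `SUPPORTED` are names invented for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical set of primitives a Relay translator already covers.
SUPPORTED = {"sin", "cos", "mul", "add"}


def primitives_used(fn, *args):
    """Return the names of the primitives in fn's jaxpr, in program order."""
    names = []

    def walk(jaxpr_like):
        for eqn in jaxpr_like.eqns:
            names.append(eqn.primitive.name)
            # Descend into nested jaxprs carried as equation parameters.
            for param in eqn.params.values():
                if hasattr(param, "eqns"):
                    walk(param)

    walk(jax.make_jaxpr(fn)(*args).jaxpr)
    return names


def f(x):
    return jnp.sin(x) * x


forward_prims = primitives_used(f, 1.0)
grad_prims = primitives_used(jax.grad(f), 1.0)

# Primitives that would still need a lowering before compilation could work.
missing = sorted((set(forward_prims) | set(grad_prims)) - SUPPORTED)
print("forward:", forward_prims)
print("gradient:", grad_prims)
print("missing lowerings:", missing)
```

The exact primitives in the gradient jaxpr depend on the JAX version, which is one reason a coverage check like this is useful before attempting a full translation.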
Aha, I get your point. Jax has a neat interface to its computation AST, on top of which a Relay graph can easily be built. There should not be any technical difficulty.
Btw, I am curious: is there any plan to support the reverse conversion, i.e. Relay → Jax? Then Jax could benefit from Relay’s powerful graph optimizations.