This is a discussion post to collect our questions and examples around the approach to BYOC (Bring Your Own Codegen) in Relax.
Here is some of the information so far:
Right now, BYOC is implemented as a hook in the TECompiler, so extending BYOC with new capabilities means adding more complexity to the hooks themselves. With the unified abstraction, we can simplify BYOC to an IRModule⇒IRModule transformation.
For a given module `MyMod`, we want to run BYOC on `conv_relu` and replace it with a version implemented by an external compiler.
```python
@tvm.script.ir_module
class MyMod:
    @R.function
    def conv_relu(x: R.Tensor((1, 10, 32, 32)), w: R.Tensor((20, 10, 3, 3))):
        # weight layout is OIHW: 20 output channels, 10 input channels
        with R.dataflow():
            lv0 = R.nn.conv2d(x, w, padding=(1, 1))
            gv0 = R.nn.relu(lv0)
            R.output(gv0)
        return gv0

    @R.function
    def main(x: R.Tensor((1, 10, 32, 32))):
        w0 = R.const(shape=(20, 10, 3, 3))  # placeholder for the constant weights
        lv0: R.Tensor((1, 20, 32, 32)) = conv_relu(x, w0)
        ...
```
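The transformation takes the module, runs a code generator for the `conv_relu` subgraph, agrees on an external global symbol name to call into it (say `conv_relu_lib`), and rewrites the module as follows:
```python
@tvm.script.ir_module
class MyModWithBYOCExtern:
    @R.function
    def main(x: R.Tensor((1, 10, 32, 32))):
        w0 = R.const(shape=(20, 10, 3, 3))  # placeholder for the constant weights
        # conv_relu is gone; its call site now invokes the external symbol
        lv0 = R.call_dps_packed(R.extern("conv_relu_lib"), [x, w0], (1, 20, 32, 32))
        ...
```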
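Here we change the `conv_relu` call site to call into an external PackedFunc in destination-passing style: the inputs and the output are allocated on the caller side and passed to the library function, which writes its result into the output. Semantically, `call_dps_packed` expands to: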
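```python
def call_dps_packed(func, inputs, out_shape):
    # The caller allocates the output buffer...
    out = alloc_tensor(out_shape)
    # ...and the external function writes its result into `out` in place.
    func(*inputs, out)
    return out
```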
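The expansion above can be exercised with a small NumPy sketch. Here `np.empty` stands in for `alloc_tensor`, and `relu_lib` is a hypothetical stand-in for an externally compiled function (a ReLU instead of the full conv+ReLU, to keep it short); neither is part of TVM.

```python
import numpy as np


def call_dps_packed(func, inputs, out_shape, dtype="float32"):
    # Caller side: allocate the output buffer (stands in for alloc_tensor).
    out = np.empty(out_shape, dtype=dtype)
    # Callee side: writes the result into `out` and returns nothing.
    func(*inputs, out)
    return out


def relu_lib(x, out):
    # Hypothetical "external library" kernel: it must fill the preallocated
    # `out` buffer rather than return a freshly allocated array.
    np.maximum(x, 0, out=out)


x = np.array([[-1.0, 2.0], [3.0, -4.0]], dtype="float32")
y = call_dps_packed(relu_lib, [x], x.shape)
print(y)
```

Because the caller owns the allocation, the library never has to return memory across the boundary, which is what lets an arbitrary PackedFunc plug into the caller's memory planning.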
The BYOC pass also generates a `runtime.Module` that contains the implementation of `conv_relu_lib` and attaches it to the `MyModWithBYOCExtern.attrs["external_mods"]` attribute, following the current BYOC convention. The final build then compiles the `main` function of `MyModWithBYOCExtern` along with the TIR functions.
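To make the IRModule⇒IRModule shape of this pass concrete, here is a toy, pure-Python sketch under heavy simplifying assumptions: `ToyIRModule`, `byoc_offload`, and `fake_codegen` are hypothetical names, not TVM APIs; a function body is just a list of `(result, callee, args)` bindings; and output shapes are omitted from the rewritten call.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple

# Toy IR (hypothetical, not TVM classes): a binding is (result_var, callee, args),
# and a function body is a list of bindings.
Binding = Tuple[str, str, Tuple[str, ...]]


@dataclass
class ToyIRModule:
    functions: Dict[str, List[Binding]]
    attrs: Dict[str, list] = field(default_factory=dict)


def fake_codegen(symbol: str, body: List[Binding]) -> dict:
    # Hypothetical external codegen: only records which ops the library implements.
    return {"symbol": symbol, "ops": [callee for _, callee, _ in body]}


def byoc_offload(mod: ToyIRModule, func_name: str, symbol: str,
                 codegen: Callable[[str, List[Binding]], dict]) -> ToyIRModule:
    """IRModule => IRModule: offload `func_name` to an external library."""
    # 1. Run the external code generator on the subgraph.
    ext_mod = codegen(symbol, mod.functions[func_name])
    # 2. Drop the offloaded function and rewrite every call to it into a
    #    call_dps_packed on the agreed external symbol (shapes omitted).
    new_funcs = {
        name: [
            (dst, "call_dps_packed", (symbol,) + args) if callee == func_name
            else (dst, callee, args)
            for dst, callee, args in body
        ]
        for name, body in mod.functions.items()
        if name != func_name
    }
    # 3. Attach the generated module under attrs["external_mods"],
    #    leaving the input module untouched.
    new_attrs = dict(mod.attrs)
    new_attrs["external_mods"] = list(mod.attrs.get("external_mods", [])) + [ext_mod]
    return ToyIRModule(new_funcs, new_attrs)


my_mod = ToyIRModule({
    "conv_relu": [("lv0", "conv2d", ("x", "w")), ("gv0", "relu", ("lv0",))],
    "main": [("lv0", "conv_relu", ("x", "w0"))],
})
new_mod = byoc_offload(my_mod, "conv_relu", "conv_relu_lib", fake_codegen)
print(new_mod.functions)
# {'main': [('lv0', 'call_dps_packed', ('conv_relu_lib', 'x', 'w0'))]}
print(new_mod.attrs["external_mods"])
# [{'symbol': 'conv_relu_lib', 'ops': ['conv2d', 'relu']}]
```

Since the pass returns a new module and leaves its input untouched, it composes like any other IRModule pass.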
There are several advantages to expressing BYOC as an IRModule⇒IRModule pass:
- We can choose when to run BYOC, and interleave multiple BYOC runs if necessary.
- Smart search and auto-tuning can be built on top as separate passes that make the decision and dispatch without overburdening the infrastructure.
- The overall connection point is the IRModule spec: since we represent the BYOC result as calls into PackedFuncs plus runtime modules in the `external_mods` attribute, we can explore many ways of smart BYOC planning without increasing the complexity of the interface.
- @sunggg provided an initial pass that implements BYOC, which reuses the existing JSON runtime but leverages
- @masahi started some groundwork to enable graph pattern rewriting and BYOC mechanisms: https://github.com/tlc-pack/relax/issues/364
- Provide runnable tutorials about BYOC on Relax
- What additional questions are there around the pattern rewriting and replacement mechanisms?
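To illustrate the first two advantages listed earlier (choosing when to run BYOC and interleaving multiple BYOC runs), here is a toy sketch. It assumes a module reduced to a map from function name to where it runs; `make_byoc_pass`, `sequential`, `lib_a`, and `lib_b` are hypothetical names, not TVM APIs.

```python
from functools import reduce
from typing import Callable, Dict, Set

# Toy module (hypothetical): function name -> where it runs ("host" = not offloaded).
Module = Dict[str, str]
Pass = Callable[[Module], Module]


def make_byoc_pass(target: str, supported: Set[str]) -> Pass:
    """Hypothetical BYOC pass: offloads every still-unclaimed function it supports."""
    def run(mod: Module) -> Module:
        return {
            name: target if where == "host" and name in supported else where
            for name, where in mod.items()
        }
    return run


def sequential(*passes: Pass) -> Pass:
    # Every pass maps a module to a module, so a pipeline is plain composition.
    return lambda mod: reduce(lambda m, p: p(m), passes, mod)


mod = {"conv_relu": "host", "matmul": "host", "softmax": "host"}
lib_a = make_byoc_pass("lib_a", {"conv_relu"})
lib_b = make_byoc_pass("lib_b", {"conv_relu", "matmul"})

# The order of BYOC runs is an ordinary pipeline decision.
print(sequential(lib_a, lib_b)(mod))
# {'conv_relu': 'lib_a', 'matmul': 'lib_b', 'softmax': 'host'}
print(sequential(lib_b, lib_a)(mod))
# {'conv_relu': 'lib_b', 'matmul': 'lib_b', 'softmax': 'host'}
```

In this framing, a search or auto-tuning pass would be whatever decides the order or the supported sets before the pipeline runs; the BYOC passes themselves stay unchanged.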
Please feel free to add your questions by replying to this post.