[DISCUSS] BYOC under relax

This is a discussion post to collect our questions and examples around the approach to BYOC in Relax.

Some of the information so far:

Possible Approaches

Right now BYOC is implemented as a hook in the TECompiler. Extending BYOC itself with new capabilities would amount to adding additional complexity to the hooks themselves. With the unified abstraction, we can simplify BYOC to an IRModule⇒IRModule transformation.

For a given MyMod, we want to run BYOC on conv_relu and replace it with a version implemented by an external compiler.

@tvm.script.ir_module
class MyMod:
    @R.func
    def conv_relu(x: R.Tensor((1, 10, 32, 32)),
                  w: R.Tensor((10, 20, 3, 3))):
        with dataflow():
            lv0 = op.conv2d(x, w, padding=(1,1))
            gv0 = op.relu(lv0)
            R.output(gv0)
        return gv0

    @R.func
    def main(x: R.Tensor((1, 10, 32, 32))):
        w0 = R.const(shape=(10, 20, 3, 3))
        lv0: R.Tensor((1, 20, 32, 32)) = conv_relu(x, w0)
        ...

The transformation takes the module, runs a code generator on the conv_relu subgraph, and agrees on an external global symbol name to call into it (say, conv_relu_lib):

@tvm.script.ir_module
class MyModWithBYOCExtern:
    @R.func
    def main(x: R.Tensor((1, 10, 32, 32))):
        w0 = R.const(shape=(10, 20, 3, 3))
        lv0 = R.call_dps_packed(R.extern("conv_relu_lib"), [x, w0], (1, 20, 32, 32))
        ...

Here we replace the call to conv_relu with a call into an external PackedFunc; the input and output tensors are allocated on the caller's side and passed to the library function. Semantically, call_dps_packed expands to

def call_dps_packed(func, inputs, out_shape):
    out = alloc_tensor(out_shape)
    func(*inputs, out)
    return out
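The destination-passing convention above can be tried out without TVM. The following is a minimal sketch in plain Python, not the Relax implementation: `alloc_tensor` and `relu_lib` are hypothetical stand-ins, and tensors are modeled as a dict holding a shape and a flat list.

```python
from math import prod


def alloc_tensor(shape):
    """Hypothetical allocator: a zero-filled flat buffer plus its shape."""
    return {"shape": tuple(shape), "data": [0.0] * prod(shape)}


def call_dps_packed(func, inputs, out_shape):
    """The caller allocates the output; the callee only fills it in."""
    out = alloc_tensor(out_shape)
    func(*inputs, out)
    return out


def relu_lib(x, out):
    """Stand-in for an external library function; writes into `out` in place."""
    for i, v in enumerate(x["data"]):
        out["data"][i] = v if v > 0.0 else 0.0


x = {"shape": (2, 2), "data": [-1.0, 2.0, -3.0, 4.0]}
y = call_dps_packed(relu_lib, [x], (2, 2))
```

Because the library function never allocates, memory planning stays on the caller's side, which is the property the convention is meant to preserve.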

Additionally, the BYOC pass generates a runtime.Module that contains the implementation of conv_relu_lib and attaches it to the MyModWithBYOCExtern.attrs["external_mods"] attribute, as per the current BYOC convention. The final build compiles the main function of MyModWithBYOCExtern along with the TIR functions.

There are many advantages to unifying BYOC as an IRModule⇒IRModule pass:

  • We can choose when to run BYOC, and interleave multiple BYOC runs if necessary.
  • Smart search and auto-tuning can be built on top as separate passes that make the decision and dispatch without overburdening the infrastructure.
  • The overall connection point is the IRModule spec: a BYOC result is represented as a call into a PackedFunc plus the modules stored in the external_mods attr. This lets us explore many ways of smart BYOC planning without increasing the complexity of the interface.
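To make the module-to-module framing concrete, here is a toy sketch in plain Python. It does not use TVM: `ToyModule`, `make_byoc_pass`, and `fake_codegen` are hypothetical names invented for illustration, and function bodies are modeled as lists of tuples. It only shows the shape of the contract described above: the offloaded function disappears, callers are redirected to an extern symbol, and the generated module is appended to `external_mods`.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ToyModule:
    """Hypothetical stand-in for an IRModule: named functions plus attrs."""
    functions: dict
    attrs: dict = field(default_factory=dict)


def make_byoc_pass(target_func, extern_symbol, codegen):
    """Build a ToyModule -> ToyModule pass that offloads one function."""
    def run(mod):
        if target_func not in mod.functions:
            return mod  # nothing to offload, so the pass is a no-op
        # Run the (hypothetical) external code generator on the subgraph.
        ext_mod = codegen(extern_symbol, mod.functions[target_func])
        functions = {}
        for name, body in mod.functions.items():
            if name == target_func:
                continue  # the offloaded function leaves the module
            # Redirect every call to the offloaded function.
            functions[name] = [
                ("call_dps_packed", extern_symbol)
                if op == ("call", target_func) else op
                for op in body
            ]
        attrs = dict(mod.attrs)
        attrs["external_mods"] = list(attrs.get("external_mods", [])) + [ext_mod]
        return ToyModule(functions, attrs)
    return run


def fake_codegen(symbol, body):
    """Stand-in for an external compiler; returns a fake runtime module."""
    return {"symbol": symbol, "ops": list(body)}


mod = ToyModule({
    "conv_relu": [("op", "conv2d"), ("op", "relu")],
    "main": [("call", "conv_relu")],
})
byoc = make_byoc_pass("conv_relu", "conv_relu_lib", fake_codegen)
out = byoc(mod)
```

Since each pass consumes and produces the same kind of object, several such passes (one per backend, say) can be chained in any order, and a planning pass can decide which ones to run without touching the interface.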

Some related resources

Possible Future Items

  • Provide runnable tutorials about BYOC on relax
  • What extra questions are there around the pattern rewriting and replacement mechanisms?

Please feel free to add your questions by replying to this post.

I’m aware of another approach to BYOC in Relax, based entirely on TIR. It is demonstrated in https://github.com/mlc-ai/relax/pull/83 (and other PRs in that repo).

Instead of using the Relax pattern matching language, this approach uses TIR pattern matching, so it is more like tensorization. Although I’d say the two approaches are complementary, I believe the graph-level one is more useful in practice. If the TIR-based one has definite advantages, I want to hear about them.

And since the Relax pattern matching language has some support for TIR matching (e.g. is_call_tir etc), I wonder if we can unify the two approaches.

Agreed on thinking about further unification :slight_smile: I think they are complementary as well.

As a reference, we now published a tutorial that shows how to use BYOC in Unity: [Unity][Tutorial] TVM Unity BYOC