This is a discussion post to collect our questions and examples around the approach to BYOC in Relax.
Some of the information so far:
Possible Approaches
Right now BYOC is implemented as a hook in the TECompiler. Extending BYOC itself with new capabilities would amount to adding additional complexity to the hooks themselves. With the unified abstraction, we can simplify BYOC to an IRModule⇒IRModule transformation.
For a given MyMod, we want to run BYOC on conv_relu and replace it with a version implemented by an external compiler.
@tvm.script.ir_module
class MyMod:
    @R.func
    def conv_relu(x: R.Tensor((1, 10, 32, 32)),
                  w: R.Tensor((10, 20, 3, 3))):
        with R.dataflow():
            lv0 = op.conv2d(x, w, padding=(1,1))
            gv0 = op.relu(lv0)
            R.output(gv0)
        return gv0

    @R.func
    def main(x: R.Tensor((1, 10, 32, 32))):
        w0 = R.const(shape=(10, 32, 3, 3))
        # pass the constant weight w0 defined above
        lv0: R.Tensor((1, 20, 32, 32)) = conv_relu(x, w0)
        ...
The transformation takes the module, runs a code generator for the conv_relu subgraph, and agrees on an external global symbol name to call into it (say conv_relu_lib):
@tvm.script.ir_module
class MyModWithBYOCExtern:
    @R.func
    def main(x: R.Tensor((1, 10, 32, 32))):
        w0 = R.const(shape=(10, 32, 3, 3))
        # inputs are x and the constant weight w0; the output shape is given explicitly
        lv0 = R.call_dps_packed(R.extern("conv_relu_lib"), [x, w0], (1, 20, 32, 32))
        ...
Here we change the call to conv_relu into a call to an external PackedFunc, where the inputs and outputs are allocated on the caller's side and passed to the library function. Semantically, call_dps_packed expands to:
def call_dps_packed(func, inputs, out_shape):
    out = alloc_tensor(out_shape)
    func(*inputs, out)
    return out
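To make the destination-passing convention concrete, here is a minimal pure-Python sketch. It is only a toy model, not TVM code: tensors are flat Python lists, and alloc_tensor and relu_lib are hypothetical stand-ins for the runtime allocator and an externally compiled function.

```python
from math import prod


def alloc_tensor(shape):
    """Caller-side allocation (toy): a zero-filled flat buffer with prod(shape) elements."""
    return [0.0] * prod(shape)


def call_dps_packed(func, inputs, out_shape):
    out = alloc_tensor(out_shape)  # the caller allocates the output buffer
    func(*inputs, out)             # the callee writes its result into `out`
    return out


def relu_lib(x, out):
    """Hypothetical external function: writes relu(x) into `out` in place."""
    for i, v in enumerate(x):
        out[i] = v if v > 0 else 0.0


result = call_dps_packed(relu_lib, [[-1.0, 2.0, -3.0, 4.0]], (2, 2))
print(result)  # [0.0, 2.0, 0.0, 4.0]
```

The point of the convention is that the library function never allocates memory itself, so memory planning stays on the caller's side.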
Additionally, the BYOC pass will generate an additional runtime.Module that contains an implementation of conv_relu_lib, which is attached to the MyModWithBYOCExtern.attrs["external_mods"] attribute as per the current BYOC convention. The final build will compile the main function of MyModWithBYOCExtern along with the TIR functions.
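The shape of this transformation can be sketched without TVM at all. The following is a toy model under stated assumptions: ToyModule, toy_codegen, and run_byoc are hypothetical names invented for illustration, functions are plain lists of (op, args) calls, and the "runtime module" is just a dict. It only shows the module-in, module-out contract and where the external module ends up.

```python
from dataclasses import dataclass, field


@dataclass
class ToyModule:
    """Hypothetical stand-in for an IRModule: named functions plus attrs."""
    functions: dict  # name -> list of (op_name, args) calls
    attrs: dict = field(default_factory=dict)


def toy_codegen(name, body):
    """Hypothetical external codegen: returns an opaque 'runtime module'."""
    return {"symbol": name + "_lib", "source": body}


def run_byoc(mod, target_func):
    """Return a NEW module in which calls to `target_func` go to an extern symbol."""
    ext = toy_codegen(target_func, mod.functions[target_func])
    new_funcs = {}
    for name, body in mod.functions.items():
        if name == target_func:
            continue  # the offloaded function now lives in the external module
        new_funcs[name] = [
            ("call_dps_packed", (ext["symbol"],) + args) if op == target_func else (op, args)
            for op, args in body
        ]
    new_attrs = dict(mod.attrs)
    new_attrs["external_mods"] = list(mod.attrs.get("external_mods", [])) + [ext]
    return ToyModule(new_funcs, new_attrs)


my_mod = ToyModule({
    "conv_relu": [("conv2d", ("x", "w")), ("relu", ("lv0",))],
    "main": [("conv_relu", ("x", "w0"))],
})
out_mod = run_byoc(my_mod, "conv_relu")
print(out_mod.functions["main"])
print([m["symbol"] for m in out_mod.attrs["external_mods"]])
```

Because run_byoc returns a new module and leaves its input untouched, it behaves like any other pass and can be scheduled anywhere in a pipeline.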
There are many advantages to unifying BYOC as an IRModule⇒IRModule pass:
- We can choose when to run BYOC, and interleave multiple BYOC runs if necessary.
- Smart search and auto-tuning can be built on top as separate passes that make the decision and dispatch without overburdening the infrastructure.
- The overall connection point is the IRModule spec: we represent the BYOC result as a call into a PackedFunc and IRModules in the external_mods attrs, so we can explore many ways of smart BYOC planning without increasing the complexity of the interface.
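As a rough illustration of the first point, the sketch below chains two BYOC-style passes over a toy module. Everything here is hypothetical and chosen for illustration only: the dict-based module, make_byoc_pass, run_pipeline, and the backend names backend_a and backend_b. It assumes a pass simply claims every op it supports that an earlier pass has not already offloaded.

```python
from functools import reduce


def make_byoc_pass(backend, supported_ops):
    """Build a hypothetical pass that offloads `supported_ops` to `backend`."""
    def byoc_pass(mod):
        body, ext = [], list(mod["external_mods"])
        for op in mod["main"]:
            if op in supported_ops:
                symbol = f"{backend}_{op}"
                body.append("extern:" + symbol)  # call into the external symbol
                if symbol not in ext:
                    ext.append(symbol)           # record the generated module
            else:
                body.append(op)                  # leave it for later passes
        return {"main": body, "external_mods": ext}
    return byoc_pass


def run_pipeline(mod, passes):
    """Apply module-to-module passes left to right."""
    return reduce(lambda m, p: p(m), passes, mod)


mod = {"main": ["conv2d", "relu", "softmax", "matmul"], "external_mods": []}
passes = [
    make_byoc_pass("backend_a", {"conv2d", "relu"}),
    make_byoc_pass("backend_b", {"matmul", "relu"}),
]
result = run_pipeline(mod, passes)
print(result["main"])
print(result["external_mods"])
```

In this toy, relu goes to backend_a because that pass runs first; swapping the order of the two passes would hand it to backend_b instead, which is the kind of decision a separate planning or tuning pass could make.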
Some related resources
- @sunggg provided an initial pass that implements BYOC that reuses the existing JSON runtime but leverages
- @masahi started some groundwork to enable graph pattern rewriting and BYOC mechanisms: https://github.com/tlc-pack/relax/issues/364
Possible Future Items
- Provide runnable tutorials about BYOC on Relax
- What extra questions are there around the pattern rewriting and replacement mechanisms?
Please feel free to add your questions by replying to this post.