[RFC] Op based annotation for external codegen
Background
We (@zhiics and @comaniac) have merged the major infra and tutorials of bring-your-own-codegen (BYOC). Currently, we expect users to write a custom pass to annotate a Relay program and then send it for partitioning. This offers great flexibility. However, one feature we removed from the infra is op based annotation, which lets developers/vendors specify only whether an operator is supported by their own codegen. With op based annotation, the annotation work is maintained by TVM, so the work required from developers/vendors is minimized.
For example, we can have various methods to merge operators to form a subgraph that can be offloaded to external codegen. A straightforward approach is merging the ops in a greedy manner so that we offload as large a subgraph as possible to an external accelerator/backend. By doing this, we can reduce the effort required from vendors, but of course, we still allow them to bring their own pass to annotate a program if they are not satisfied with the greedy approach.
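To make the greedy idea concrete, here is a minimal sketch. It assumes the program is a simple linear chain of ops, which is a simplification: a real Relay graph is a DAG, and the actual merging algorithm is still to be decided. The function name `greedy_merge` and the `supported` set are made up for illustration only.

```python
from typing import Callable, List


def greedy_merge(ops: List[str], is_supported: Callable[[str], bool]) -> List[List[str]]:
    """Group maximal runs of consecutive supported ops in a linear chain."""
    regions: List[List[str]] = []
    current: List[str] = []
    for op in ops:
        if is_supported(op):
            # Keep growing the current region as long as ops are supported.
            current.append(op)
        elif current:
            # An unsupported op closes the region built so far.
            regions.append(current)
            current = []
    if current:
        regions.append(current)
    return regions


# Hypothetical set of ops an external codegen claims to support.
supported = {"nn.conv2d", "nn.relu", "add"}
chain = ["nn.conv2d", "nn.relu", "nn.softmax", "add", "nn.relu"]
regions = greedy_merge(chain, supported.__contains__)
print(regions)
```

Each inner list is a candidate subgraph to offload; the unsupported `nn.softmax` stays with TVM and splits the chain into two regions.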
Considerations
The reason we removed it from previous BYOC PRs was that we were still thinking about 1) how to integrate it with the op strategy (PR #4644), and 2) a proper interface for developers/vendors.
After consideration, we concluded that graph annotation for external codegen is actually orthogonal to op strategy, because graph annotation is done before the vendor-independent Relay passes (such as constant folding) are executed. Op strategy, on the other hand, mainly focuses on selecting compute and schedule functions for ops that have been determined to be compiled by TVM. As a result, op strategy happens after all passes are complete.
Also, the op based annotation is orthogonal to the recently merged composite pass (PR #4771), which accepts user-written Relay graph patterns and uses pattern matching to make subgraphs. Our proposed op based approach spares users from writing possibly countless patterns.
Implementation in TVM
For each op, we have a corresponding check function registered, and the checker will be invoked at compilation time to indicate whether the op should be annotated for offloading to the 3rd party accelerator. For example, the following code shows an implementation of the helper for checking if an op should be offloaded to a given compiler:
def register_external_compiler_helper(op_name):
    @reg.register_external_compiler(op_name)
    def _register_wrapper(attrs, args, comp):
        return get_external_compiler(comp, op_name)(attrs, args)
    return _register_wrapper

# Register all Relay ops
register_external_compiler_helper("nn.conv2d")
register_external_compiler_helper("nn.relu")
register_external_compiler_helper("add")
...
- Note that `comp` is a module name in the 3rd party compiler (e.g., the name `dnnl` indicates the module `python/tvm/relay/backend/contrib/dnnl/dnnl.py`); `get_external_compiler` uses `hasattr` and `getattr` to obtain the 3rd party specified checkers (implemented in the following section).
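As a rough illustration of the `hasattr`/`getattr` lookup, here is a self-contained sketch. Several things in it are assumptions rather than the proposed implementation: the lookup takes a full module path instead of a short name like `dnnl`, the checker is assumed to be named after the last component of the op name, and a missing checker is treated as "not supported". The module `fake_contrib_dnnl` is a hypothetical stand-in for a vendor module.

```python
import importlib
import sys
import types

# Hypothetical stand-in for a vendor module such as
# python/tvm/relay/backend/contrib/dnnl/dnnl.py.
fake = types.ModuleType("fake_contrib_dnnl")


def conv2d(attrs, args):
    # Vendor checker: claims conv2d is always supported.
    return True


fake.conv2d = conv2d
sys.modules["fake_contrib_dnnl"] = fake


def get_external_compiler(module_path, op_name):
    """Look up the checker for `op_name` in the module at `module_path`."""
    mod = importlib.import_module(module_path)
    # Assumption: the checker is named after the last part of the op name,
    # so "nn.conv2d" maps to a function called "conv2d".
    func_name = op_name.split(".")[-1]
    if hasattr(mod, func_name):
        return getattr(mod, func_name)
    # Assumption: an op without a checker is not offloaded.
    return lambda attrs, args: False


print(get_external_compiler("fake_contrib_dnnl", "nn.conv2d")(None, []))
print(get_external_compiler("fake_contrib_dnnl", "nn.softmax")(None, []))
```

Falling back to a checker that returns `False` means an op the vendor never mentioned simply stays with TVM.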
Required Implementation by Developers/Vendors
- HW partners and 3rd party library vendors only need to attach an op to a provided helper, or they can implement a simple checker for an op to specify whether they support it under certain conditions. For example, they can do any of the following:
_register_external_op_helper("conv2d")
_register_external_op_helper("conv2d", True)
_register_external_op_helper("conv2d", False)

def conv2d(attrs, args):
    if ...:
        return True
    return False
- Where `_register_external_op_helper("conv2d")` is a simple version of
def conv2d(attrs, args):
    return True
- Note that HW partners do not need to register this function but just need to implement it under `python/tvm/relay/backend/contrib/compiler_name/comp.py` so that the function can be discovered and imported dynamically.
- A Relay IR pass, `AnnotateCompiler`, in Python will invoke the above function, insert annotations into the graph, and run the greedy algorithm (TBA) to merge ops.
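For reference, one way the vendor-side helper could work is sketched below. This is an assumption about its behavior, not the actual helper: it takes the target module as an explicit first argument (the helper in this RFC does not) so that the example is self-contained, and `vendor_checkers` is a hypothetical module name.

```python
import types


def _register_external_op_helper(module, op_name, supported=True):
    """Define a checker named `op_name` on `module` that always returns `supported`."""
    def _func_wrapper(attrs, args):
        return supported

    _func_wrapper.__name__ = op_name
    # Attach the checker so it can later be found with hasattr/getattr.
    setattr(module, op_name, _func_wrapper)
    return _func_wrapper


# Hypothetical vendor module holding the checkers.
vendor = types.ModuleType("vendor_checkers")
_register_external_op_helper(vendor, "conv2d")          # always supported
_register_external_op_helper(vendor, "softmax", False)  # never supported


def relu(attrs, args):
    # Hand-written checker with an illustrative condition:
    # only offload relu when it has exactly one input.
    return len(args) == 1


vendor.relu = relu

print(vendor.conv2d(None, []), vendor.softmax(None, []), vendor.relu(None, ["x"]))
```

The helper covers the common "always yes" or "always no" cases in one line, while a hand-written function like `relu` handles conditional support.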
How to Use
from tvm import relay
from tvm.relay import transform

mod = relay.IRModule()
seq = transform.Sequential([transform.AnnotateCompiler("dnnl"), transform.GraphPartition()])
mod = seq(mod)
graph, lib, params = relay.build(mod, "llvm")
Any comments are highly appreciated:)