Background
The current op lowering and legalization mechanism relies on the TVM_REGISTER_GLOBAL macro and stores all functions in the global runtime registry, tvm::runtime::Registry. While simple, this practice creates several problems.
P0. The global registry is crowded with all kinds of functions, of which intrinsic lowering functions are only a small part, so the lowering functions are hard to manage.
P1. The current lowering pass relies on string-based function matching. While useful, this approach becomes hard to debug and extend as the number of ops grows and new targets are introduced. String-based matching also offers no clear way to implement an efficient and reliable fallback mechanism.
P2. Some ops are not supported on certain targets and therefore require legalization.
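To make P0 and P1 concrete, here is a toy Python model of a flat, string-keyed registry. It is not TVM's implementation: the key pattern tvm.intrin.rule.<target>.<op> is assumed for illustration, and GLOBAL_REGISTRY, register_global, and find_lowering_rule are hypothetical names.

```python
from typing import Callable, Dict, Optional

# Hypothetical flat registry: every global function, lowering rule or not, shares one namespace.
GLOBAL_REGISTRY: Dict[str, Callable] = {}


def register_global(name: str) -> Callable[[Callable], Callable]:
    """Store a function under an arbitrary string key (toy stand-in for a global registry)."""
    def wrap(fn: Callable) -> Callable:
        GLOBAL_REGISTRY[name] = fn
        return fn
    return wrap


@register_global("tvm.intrin.rule.cuda.exp")
def lower_exp_cuda(call):
    # Pretend lowering: map the generic op to a CUDA math function name.
    return "__expf"


@register_global("runtime.UnrelatedHelper")
def unrelated_helper():
    # Unrelated functions live in the same table as the lowering rules.
    return None


def find_lowering_rule(target: str, op_name: str) -> Optional[Callable]:
    # The rule is located only by string concatenation, so a mismatch in either
    # part (e.g. passing "tir.exp" instead of "exp") silently returns None.
    return GLOBAL_REGISTRY.get(f"tvm.intrin.rule.{target}.{op_name}")
```

In this model, `find_lowering_rule("cuda", "exp")` finds the rule, while `find_lowering_rule("cuda", "tir.exp")` and `find_lowering_rule("llvm", "exp")` both return None, and nothing in the lookup itself says which default rule, if any, should be tried next.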
Proposal
With the op registry now available, we can move the op intrinsic lowering and legalization functions to a unified TVM_REGISTER_OP-based registration, as follows:
```cpp
// Each attribute's template type must match its key: FLowerIntrinsic for
// "<target>.FLowerIntrinsic" and FLegalize for "<target>.FLegalize".
TVM_REGISTER_OP("tir.exp")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", LowerDefault)
    .set_attr<FLegalize>("default.FLegalize", LegalizeDefault)
    .set_attr<FLegalize>("cuda.FLegalize", LegalizeOnCUDA)
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", LowerOnCUDA)
    .set_attr<FLegalize>("llvm.FLegalize", LegalizeOnLLVM)
    .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", LowerOnLLVM)
    .set_attr<FLegalize>("vulkan.FLegalize", LegalizeOnVulkan)
    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", LowerOnVulkan);
```
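The key difference from the global registry is that all rules for one op are grouped on that op's own entry. A minimal Python model of that idea, assuming a chainable set_attr like the C++ builder, might look like this; OpEntry, OP_REGISTRY, and register_op are hypothetical names, and the string-returning lambdas stand in for real lowering functions.

```python
from typing import Any, Dict


class OpEntry:
    """Hypothetical stand-in for an op registry entry that owns per-op attributes."""

    def __init__(self, name: str):
        self.name = name
        self.attrs: Dict[str, Any] = {}

    def set_attr(self, key: str, value: Any) -> "OpEntry":
        # In this sketch, refuse silent overwrites so conflicting registrations surface early.
        if key in self.attrs:
            raise KeyError(f"{self.name} already has attribute {key}")
        self.attrs[key] = value
        # Returning self allows the chained style used with TVM_REGISTER_OP.
        return self


OP_REGISTRY: Dict[str, OpEntry] = {}


def register_op(name: str) -> OpEntry:
    # Create the entry on first use; later calls extend the same entry.
    if name not in OP_REGISTRY:
        OP_REGISTRY[name] = OpEntry(name)
    return OP_REGISTRY[name]


(register_op("tir.exp")
    .set_attr("default.FLowerIntrinsic", lambda call: "exp")
    .set_attr("cuda.FLowerIntrinsic", lambda call: "__expf"))
```

Because the rules hang off the op entry, listing everything registered for "tir.exp" is a single dictionary lookup rather than a scan of every global function name.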
Here, FLowerIntrinsic denotes the intrinsic lowering function and FLegalize denotes the legalization function. Each attribute key is prefixed with a target name (default, cuda, llvm, vulkan), which supports a fallback mechanism: when no rule is registered for the current target, the default.* rule can be used instead. The functions are stored as PackedFunc objects and invoked during the intrinsic lowering and legalization passes. All existing intrinsic lowering and legalization rules can be ported to the new mechanism.
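The target-name fallback could work roughly as follows. This minimal sketch assumes that a lookup tries the "<target>.<kind>" key before the "default.<kind>" key; find_rule, lower_call, and exp_attrs are hypothetical names, and a real pass would operate on TIR Call nodes and PackedFunc objects rather than Python callables and strings.

```python
from typing import Callable, Dict, Optional


def find_rule(op_attrs: Dict[str, Callable], target: str, kind: str) -> Optional[Callable]:
    """Look up "<target>.<kind>" first, then fall back to "default.<kind>"."""
    for prefix in (target, "default"):
        rule = op_attrs.get(f"{prefix}.{kind}")
        if rule is not None:
            return rule
    return None  # no rule at all: the caller leaves the op untouched


def lower_call(op_attrs: Dict[str, Callable], target: str, call):
    # Apply the selected FLowerIntrinsic rule, or return the call unchanged.
    rule = find_rule(op_attrs, target, "FLowerIntrinsic")
    return rule(call) if rule is not None else call


# Attributes registered for a hypothetical "tir.exp" op (rules return strings for brevity).
exp_attrs: Dict[str, Callable] = {
    "default.FLowerIntrinsic": lambda call: "exp",
    "cuda.FLowerIntrinsic": lambda call: "__expf",
}
```

With these attributes, lowering for "cuda" picks the CUDA rule, lowering for "llvm" falls back to the default rule, and a target with no rule and no default leaves the call as it was.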
On the Python side, we can also create a new decorator, tvm.register_op, used as in the following example:
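```python
import tvm
from vta.environment import get_env

# tvm.register_op is the decorator proposed in this RFC.
@tvm.register_op("tir.vta.coproc_dep_pop")
def coproc_dep_pop(op):
    # Lower the intrinsic to an external call into the VTA runtime.
    return tvm.tir.call_extern(
        "int32", "VTADepPop", get_env().dev.command_handle, op.args[0], op.args[1]
    )
```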
This gives us a concise and clear way to register and invoke intrinsic lowering and legalization functions at the TIR level.
Steps to Take
S0. Implement the op registration and lowering mechanism with FLowerIntrinsic in the intrinsic lowering pass.
S1. Port all existing target-based lowering rules to the new mechanism.
S1.1 Port the Python API to support registering op lowering functions.
S2. Add missing ops that can be supported via cmath, making the set of ops as close to NumPy as possible.
S3. Implement the same mechanism for FLegalize.
S4. Add legalization for ops that need it on specific targets.
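Steps S3 and S4 could follow the same pattern for legalization. The sketch below is a hypothetical illustration only: it models expressions as nested tuples, assumes a target without a native exp10, and registers a rule that rewrites exp10(x) as exp(x * ln(10)). LEGALIZE_RULES, legalize, and evaluate are names invented for this example.

```python
import math
from typing import Callable, Dict

# Expressions are nested tuples ("op", arg, ...) or plain floats (hypothetical IR).


def legalize_exp10(call: tuple) -> tuple:
    # exp10(x) == exp(x * ln 10); rewrite for a target assumed to lack exp10.
    _, x = call
    return ("tir.exp", ("mul", x, math.log(10.0)))


# Hypothetical FLegalize table for one target.
LEGALIZE_RULES: Dict[str, Callable[[tuple], tuple]] = {"tir.exp10": legalize_exp10}


def legalize(expr):
    """Rewrite bottom-up, applying a rule wherever one is registered."""
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    node = (op, *(legalize(a) for a in args))
    rule = LEGALIZE_RULES.get(op)
    return rule(node) if rule is not None else node


def evaluate(expr) -> float:
    # Tiny interpreter, used only to check that legalization preserves the value.
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    vals = [evaluate(a) for a in args]
    if op == "tir.exp":
        return math.exp(vals[0])
    if op == "tir.exp10":
        return 10.0 ** vals[0]
    if op == "mul":
        return vals[0] * vals[1]
    raise ValueError(f"unknown op {op}")
```

Here legalization only rewrites unsupported ops into supported ones, so a later intrinsic lowering step would see "tir.exp" and could apply its own FLowerIntrinsic rule.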
Pull Request
This pull request resolves steps S0 and S1.
Please let me know your thoughts and opinions. Thanks!
@junrushao @tqchen @jroesch and any other contributors who are interested!