The way I understand the FuseOps pass is that it creates, from a chain of Relay operators, a multistage operator to be lowered to TIR:
```
// Before FuseOps
op1(args_op1) {
    // TIR of op1
}
op2(args_op2) {
    // TIR of op2
}
// Note: these would be handled by 2 different lowering processes
// (i.e. 2 operators, each single stage)

// After FuseOps
op1_op2([args_op1, args_op2]) {
    // TIR of op1
    // TIR of op2
}
// Note: these would be handled by 1 lowering process
// (i.e. 1 operator with 2 stages)
```
In some literature this is called “operator stacking”. I think the terminology is very transparent about what it does: it literally creates one function which is the stacked version of each operator, with no real melting of the schedules.
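To make this concrete, here is a minimal sketch of how I picture it, using the Python `relay.transform` API (the ops and shapes are just placeholders, not a specific real workload):

```python
import tvm
from tvm import relay

# Two injective ops chained at the Relay level
x = relay.var("x", shape=(1, 16), dtype="float32")
y = relay.add(x, relay.const(1.0))
z = relay.nn.relu(y)
mod = tvm.IRModule.from_expr(relay.Function([x], z))

# FuseOps groups add and relu into a single primitive function,
# i.e. the "stacked" operator that is then lowered as one unit
mod = relay.transform.InferType()(mod)
mod = relay.transform.FuseOps(fuse_opt_level=2)(mod)
print(mod)  # the fused function is marked Primitive and contains both ops
```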
In the TVM tech report, the following is stated (and I assume that the current FuseOps implementation does the same):
> We provide generic rules to fuse these operators, as follows. Multiple injective operators can be fused into another injective operator. A reduction operator can be fused with input injective operators (e.g., fuse scale and sum). Operators such as conv2d are complex-out-fusable, and we can fuse element-wise operators to its output.
In the VTA example, as Thierry said:

> At least the clip/cast are inserted by the quantization process of the network (or at least that is the case in the deploy_detection tutorial).
Since these ops are element-wise (e-wise) and the FuseOps routine can fuse Conv2D with e-wise operators, they get a “stacked” version of the conv2d plus the following e-wise ops when lowering the conv2d Relay operator, and can then create a schedule like:
```
conv2d_ewise([args_conv2d, args_e_wise1, ...]) {
    // TIR of intertile Conv2d
    // (the indentation illustrates that the next TIRs live in subscopes,
    //  i.e. inside {}, of the intertile Conv2d)
        // TIR of DMA loads
        // TIR of intratile Conv2d
        // TIR of intratile e_wise1
        // TIR of intratile e_wise2
        // ...
        // TIR of DMA stores
}
```
They do this with compute_at() statements in their conv2d schedule. But again, these “stages” are only available here because FuseOps supports fusing them.
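As a rough illustration of the compute_at idea (this is not the actual VTA schedule; the topi ops, shapes and split factor are just placeholders):

```python
import tvm
from tvm import te, topi

# Placeholder conv2d + e-wise chain at the TE level
data = te.placeholder((1, 16, 32, 32), name="data")
kernel = te.placeholder((16, 16, 3, 3), name="kernel")
conv = topi.nn.conv2d_nchw(data, kernel, stride=1, padding=1, dilation=1)
relu = topi.nn.relu(conv)

s = te.create_schedule(relu.op)
# Tile the output (e-wise) stage, then pull the conv2d stage into the
# inner tile scope so both end up inside the same loop nest
xo, xi = s[relu].split(relu.op.axis[3], factor=16)
s[conv].compute_at(s[relu], xo)
print(tvm.lower(s, [data, kernel, relu], simple_mode=True))
```

VTA's real schedule of course also handles the DMA copies and tensorization; the point here is only the compute_at structure over the fused stages.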
If FuseOps were not able to fuse conv2d with the following e-wise operators, then the VTA implementation would have had to either:
- O1: “Inject” the e-wise ops into the conv2d schedules when lowering to TIR. This would be a problem if they required some variables available at the Relay level which are not part of the conv2d argument list.
- O2: Create Relay ops which represent the fused operators as produced by the current FuseOps, and do Relay-level graph rewrites (see the sketch after this list). This is the suggestion most people here are leaning towards.
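A minimal sketch of what such a Relay-level rewrite could look like with the dataflow pattern machinery (the composite name "my_hw.conv2d_relu" and the conv2d+relu chain are made-up placeholders):

```python
from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard, rewrite

class StackConv2dRelu(DFPatternCallback):
    """Wrap a matched conv2d + relu chain into one "stacked" inner function."""

    def __init__(self):
        super().__init__()
        self.data = wildcard()
        self.weight = wildcard()
        self.pattern = is_op("nn.relu")(is_op("nn.conv2d")(self.data, self.weight))

    def callback(self, pre, post, node_map):
        data = node_map[self.data][0]
        weight = node_map[self.weight][0]
        # Rebuild the matched body as an inner function and tag it, so a
        # target-specific lowering can treat it as a single multistage op
        d = relay.var("data")
        w = relay.var("weight")
        func = relay.Function([d, w], relay.nn.relu(relay.nn.conv2d(d, w)))
        func = func.with_attr("Composite", "my_hw.conv2d_relu")
        return relay.Call(func, [data, weight])

# usage: new_body = rewrite(StackConv2dRelu(), mod["main"].body)
```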
My case is similar to the above problem. The chain of Relay operators that I want to first “stack” and then actually “melt” using compute_at is not supported by FuseOps.
My problems with defining Relay operators which are the “stacked” version of already available Relay operators are:
- P1: I don’t think the Relay documentation available for adding an operator is sufficient for me to actually do it.
- P2: These new Relay operators would(?) have to be tested in order to make sure that all Relay passes still function as expected. Note that this Relay-level rewrite would need to happen at a similar level as the FuseOps pass.
- P3: These new Relay operators would most likely be HW-dependent and not necessarily portable. But I guess this can be handled by enabling HW-dependent Relay rewrites.
I think a more elegant solution would be to allow a definition of “stacked operators” at the Relay level, so that essentially every MergeComposite match creates the “stacked” multistage operator. If the composite is given to an external codegen, it doesn’t use this “multistage” information, but if it’s given to the TIR lowering process, a specific target can use it.
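Roughly what I have in mind, reusing the existing MergeComposite machinery (the pattern name "my_hw.conv2d_clip_cast", dtypes and shapes are just placeholders):

```python
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

def conv2d_clip_cast_pattern():
    conv = is_op("nn.conv2d")(wildcard(), wildcard())
    return is_op("cast")(is_op("clip")(conv))

pattern_table = [("my_hw.conv2d_clip_cast", conv2d_clip_cast_pattern())]

# A toy module containing such a chain
x = relay.var("x", shape=(1, 16, 32, 32), dtype="int8")
w = relay.var("w", shape=(16, 16, 3, 3), dtype="int8")
y = relay.cast(relay.clip(relay.nn.conv2d(x, w, out_dtype="int32"), 0, 127), "int8")
mod = tvm.IRModule.from_expr(relay.Function([x, w], y))

mod = relay.transform.InferType()(mod)
# Each match becomes an inner function tagged Composite="my_hw.conv2d_clip_cast";
# the idea would be to let a TIR-lowering target consume that function as one
# multistage ("stacked") operator, instead of only external codegens using it.
mod = relay.transform.MergeComposite(pattern_table)(mod)
print(mod)
```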
@tqchen would you mind giving your thoughts?