Currently, there is an IR pass to combine parallel Conv2D ops. This is a kind of fusion strategy that reduces the number of kernels launched, improves data locality for those kernels, and reduces latency.
Models like BERT start each layer with three parallel branches that all use the same input and have the same sequence of operations. These branches start with a matmul op. By fusing these matmuls, we can go from three sequential matrix multiplications of size (128,768) x (768,768) to one batch matrix multiplication of size (3,128,768) x (3,768,768). On GPU and multi-core CPU, this should provide a significant speedup.
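For concreteness, here is a quick numpy check of that shape transformation (tensor names are made up for illustration):

```python
import numpy as np

x = np.random.rand(128, 768).astype("float32")                       # shared input
w = [np.random.rand(768, 768).astype("float32") for _ in range(3)]   # per-branch weights

# Three separate matmuls: (128, 768) x (768, 768) each.
separate = [x @ w_i for w_i in w]

# One batch matmul: (3, 128, 768) x (3, 768, 768).
x_stacked = np.broadcast_to(x, (3, 128, 768))   # same input for every branch
w_stacked = np.stack(w, axis=0)
batched = np.matmul(x_stacked, w_stacked)

for i in range(3):
    np.testing.assert_allclose(batched[i], separate[i], rtol=1e-4, atol=1e-4)
```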
I propose to make an IR pass called CombineParallelDense, which will combine parallel Dense ops into one BatchMatMul op. This will be optionally followed by a single batch Add op if the "units" parameter of Dense is not null. This combine can be done in a very similar way to CombineParallelConv2D, and can even fuse the element-wise operations following it.
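To make the target form concrete, here is a rough, hand-written Relay sketch of the rewrite I have in mind (illustration only, not the output of an actual pass; the optional batch Add is left out here and discussed further below):

```python
from tvm import relay

x = relay.var("x", shape=(128, 768))
w = [relay.var("w%d" % i, shape=(768, 768)) for i in range(3)]

# Original form: three parallel dense ops sharing the same input.
branches = [relay.nn.dense(x, w[i]) for i in range(3)]

# Combined form: stack the inputs and weights, do one batch matmul, split back.
x_stacked = relay.stack([x, x, x], axis=0)        # (3, 128, 768)
w_stacked = relay.stack(w, axis=0)                # (3, 768, 768)
y = relay.nn.batch_matmul(x_stacked, w_stacked)   # (3, 128, 768); batch_matmul, like
                                                  # dense, transposes its second operand
outputs = relay.split(y, indices_or_sections=3, axis=0)
# Each piece is (1, 128, 768) and would still need a squeeze/reshape back to (128, 768).
```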
What do you think? How do you think the implementation should look so we don't have to copy code between the two "combine" passes?
Here are two of the parallel branches that use the same input (the third branch is very far to the right in Netron, so it's not shown here).
Do you mean optional as in being part of optimization level 4?
Reusing the code should be doable. The tricky parts are:
CombineParallelConv2D concatenates the input then splits the output. CombineParallelDense will stack the inputs then slice the output (see the sketch after these points).
CombineParallelConv2D has some extra logic based on the size of the channel dimension. This doesn't apply to CombineParallelDense, because the sizes of the matrix multiplications need to match exactly.
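Here is a small numpy sketch of that structural difference (shapes are illustrative only):

```python
import numpy as np

a, b, c = (np.random.rand(128, 768).astype("float32") for _ in range(3))

# Conv2D-style combine: concatenate along an existing (channel-like) axis, then split.
cat = np.concatenate([a, b, c], axis=1)   # (128, 2304)
a2, b2, c2 = np.split(cat, 3, axis=1)     # three (128, 768) pieces again

# Dense-style combine: stack along a new batch axis, then slice it back off.
stk = np.stack([a, b, c], axis=0)         # (3, 128, 768)
a3, b3, c3 = stk[0], stk[1], stk[2]       # three (128, 768) slices
```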
Maybe we can make an abstract CombineParallelOp class, and have the methods take in an optional argument map, so implementations like CombineParallelConv2D can take in the channel size as an argument.
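Roughly, the interface could look something like this (pseudo-Python just to sketch the idea; the real passes live in C++, and all names here are hypothetical):

```python
from abc import ABC, abstractmethod

class CombineParallelOp(ABC):
    """Combines parallel ops that consume the same input."""

    def __init__(self, op_name, min_num_branches=3, args=None):
        self.op_name = op_name                   # e.g. "nn.conv2d" or "nn.dense"
        self.min_num_branches = min_num_branches
        self.args = args or {}                   # optional per-op arguments

    @abstractmethod
    def can_combine(self, branches):
        """Check whether the parallel branches are compatible."""

    @abstractmethod
    def combine(self, branches):
        """Build the combined expression (concat/stack, op, split/slice)."""

class CombineParallelConv2D(CombineParallelOp):
    def __init__(self, channel_threshold):
        super().__init__("nn.conv2d", args={"channel_threshold": channel_threshold})

    def can_combine(self, branches):
        ...  # check attrs plus the channel-dimension condition from self.args

    def combine(self, branches):
        ...  # concatenate inputs along the channel axis, run one conv2d, split

class CombineParallelDense(CombineParallelOp):
    def __init__(self):
        super().__init__("nn.dense")

    def can_combine(self, branches):
        ...  # matrix shapes must match exactly

    def combine(self, branches):
        ...  # stack inputs/weights, batch_matmul, slice the output
```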
Yes, we can put it in level 4 and invoke it based on profiling results (like AutoTVM).
It would be great if we could refactor the CombineParallelConv2D pass and extract the common parts between the two passes.
You mentioned using AutoTVM to decide whether or not to run this pass. Is there an example I can go off of? I thought that CombineParallelConv2D was always invoked if you choose opt level 4.
Also, how do you think I should handle combining the element-wise ops at the end? For example, normally the output of shape (128,768) would have a bias tensor of shape (768) added to it. This 1D tensor would be broadcast. However, when the output tensor is of shape (3,128,768), I don't think we can properly broadcast-add a tensor of shape (3,768).
Unfortunately there are no examples right now. For the broadcast ops, we can pad a 1 onto the broadcast dimension so that lhs and rhs have the same number of dimensions, meaning that (768,) will be cast to (1, 768).
Sorry, would you be able to expand on that last part? I donât quite understand. We basically have three different bias adds for each inner matrix in (3,128,768).
Do you mean the bias adds will be part of a new stacked tensor of shape (3,1,768)?
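If so, a quick numpy shape check (just to confirm my understanding of the broadcast):

```python
import numpy as np

out = np.zeros((3, 128, 768), dtype="float32")                     # stacked matmul output
biases = [np.random.rand(768).astype("float32") for _ in range(3)]

stacked = np.stack([b.reshape(1, 768) for b in biases], axis=0)    # (3, 1, 768)
res = out + stacked                       # broadcasts cleanly to (3, 128, 768)
assert res.shape == (3, 128, 768)

# Whereas a (3, 768) bias would not broadcast against (3, 128, 768):
# out + np.stack(biases, axis=0)  ->  ValueError
```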
I believe that using a library like MKL's batch matrix multiplication can be more efficient than running multiple matrix multiplications in sequence. This depends on the problem size too.