Hello there. The idea is the same as the existing IR pass described in [Discussion] New IR pass proposal: CombineParallelDense by @jonso. Many sequential network structures perform a group of matmul operations on the same input tensor, such as:
- gate projections on the state within GRU/LSTM
- Q/K/V projections on the input within a transformer layer
With the CombineParallelDense pass, such operations can be combined to make fuller use of the matmul kernels' performance.
The currently implemented strategy transforms multiple matmuls into one batched matmul op:
Y_1: [M, N] = matmul(X: [M, K], W_1: [K, N]), …, Y_B: [M, N] = matmul(X: [M, K], W_B: [K, N])
Y: [B, M, N] = batch_matmul(stack(X, …, X): [B, M, K], stack(W_1, …, W_B): [B, K, N])
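A minimal NumPy sketch of the batched strategy, assuming plain NumPy matmul semantics (not TVM Relay's own batch_matmul conventions). The helper name `combine_as_batch_matmul` is hypothetical. It shows why this form needs every W_i to have the same N: `np.stack` only accepts equally shaped arrays.

```python
import numpy as np


def combine_as_batch_matmul(X, Ws):
    """Batched strategy in NumPy notation.

    X:  [M, K] shared input
    Ws: list of B weight matrices, each [K, N] (all N must be equal)
    Returns Y: [B, M, N] with Y[i] == X @ Ws[i].
    """
    B = len(Ws)
    X_stacked = np.stack([X] * B)   # [B, M, K]: B copies of the same input
    W_stacked = np.stack(Ws)        # [B, K, N]: raises ValueError if N differ
    return np.matmul(X_stacked, W_stacked)  # batched matmul -> [B, M, N]
```

Each slice Y[i] of the result reproduces the corresponding separate matmul X @ W_i.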
However, there is another, simpler choice: combine them into a single matmul instead of a batched matmul. This also works when the output channel sizes differ:
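Y_1: [M, N_1] = matmul(X: [M, K], W_1: [K, N_1]), …, Y_B: [M, N_B] = matmul(X, W_B: [K, N_B])
Y: [M, N_1 + N_2 + … + N_B] = matmul(X, concat(W_1, …, W_B, axis=1))
Y_1, …, Y_B = split(Y, indices=[N_1, N_1 + N_2, …, N_1 + … + N_{B-1}], axis=1)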
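A minimal NumPy sketch of the single-matmul strategy, assuming plain NumPy semantics; `combine_as_single_matmul` is a hypothetical helper name. The weights are concatenated along the output-channel axis (a stack would require equal widths), one matmul is run, and the combined output is split at the running sums of the N_i.

```python
import numpy as np


def combine_as_single_matmul(X, Ws):
    """Concat strategy in NumPy notation.

    X:  [M, K] shared input
    Ws: list of weight matrices [K, N_i]; the N_i may differ
    Returns [Y_1, ..., Y_B] with Y_i == X @ Ws[i].
    """
    W_cat = np.concatenate(Ws, axis=1)  # [K, N_1 + ... + N_B]
    Y = X @ W_cat                       # one matmul: [M, N_1 + ... + N_B]
    # Split points are the running sums N_1, N_1 + N_2, ..., excluding the total
    split_points = np.cumsum([W.shape[1] for W in Ws])[:-1]
    return np.split(Y, split_points, axis=1)
```

With widths such as 4, 8 and 2, only this strategy applies; when all widths are equal, both strategies produce the same values.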
Since matmul and batch_matmul are different op implementations, the performance of the combined op may differ between the two strategies. The output layouts also differ ([B, M, N] versus [M, N_1 + … + N_B]), which may affect the performance of downstream ops.
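To illustrate the layout difference, here is a small NumPy sketch assuming row-major (C-order) buffers; TVM's actual buffer layouts and how fused downstream ops access them may differ. In the batched output each branch is a contiguous [M, N] block, while in the concatenated output each branch is a strided column slice whose rows are N_1 + … + N_B elements apart.

```python
import numpy as np

M, K, N, B = 4, 8, 6, 3
rng = np.random.default_rng(0)
X = rng.standard_normal((M, K))
Ws = [rng.standard_normal((K, N)) for _ in range(B)]

Y_batched = np.matmul(np.stack([X] * B), np.stack(Ws))  # [B, M, N]
Y_concat = X @ np.concatenate(Ws, axis=1)               # [M, B*N]

# Output of branch 1 in each layout
y1_batched = Y_batched[1]         # contiguous [M, N] block
y1_concat = Y_concat[:, N:2 * N]  # strided view into the wide output

print(y1_batched.flags["C_CONTIGUOUS"])  # True
print(y1_concat.flags["C_CONTIGUOUS"])   # False
print(y1_concat.strides)                 # (144, 8): row step is B*N float64s
```

A downstream op that wants a dense [M, N] input can read the batched slice directly, but may need a copy or strided access for the concatenated one.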
We can compare matmul against the equivalent batch_matmul with a fixed LHS matrix. Using cuBLAS as a reference, I found that a single cublasSgemm call is significantly faster than cublasSgemmStridedBatched in certain circumstances with small B (typically 3).
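A minimal harness for setting up this kind of comparison is sketched below. It uses NumPy on the CPU as a stand-in for cuBLAS, so its timings say nothing about the cuBLAS observation above; it only shows how to pose the two problems fairly: the same shared LHS, the same 2·M·K·N·B FLOPs, and a broadcast view so the batch reuses one LHS buffer instead of materializing B copies. The function names and default sizes are illustrative assumptions.

```python
import time

import numpy as np


def bench(fn, repeats):
    """Average wall-clock seconds per call, after one warm-up call."""
    fn()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats


def compare(M=64, K=512, N=512, B=3, repeats=20, seed=0):
    """Time one [M, K] x [K, B*N] matmul against a B-way batched matmul
    that shares the same LHS. Both do the same 2*M*K*N*B FLOPs."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((M, K), dtype=np.float32)
    Ws = rng.standard_normal((B, K, N), dtype=np.float32)

    W_cat = np.concatenate(list(Ws), axis=1)  # [K, B*N] for the single matmul
    X_b = np.broadcast_to(X, (B, M, K))       # shared LHS view, batch stride 0

    t_single = bench(lambda: X @ W_cat, repeats)
    t_batched = bench(lambda: np.matmul(X_b, Ws), repeats)

    # Sanity check: reorder [M, B*N] into [B, M, N] and compare results
    Y_single = (X @ W_cat).reshape(M, B, N).transpose(1, 0, 2)
    Y_batched = np.matmul(X_b, Ws)
    same = np.allclose(Y_single, Y_batched, rtol=1e-3, atol=1e-3)
    return t_single, t_batched, same


if __name__ == "__main__":
    t_s, t_b, ok = compare()
    print(f"single: {t_s * 1e3:.3f} ms, batched: {t_b * 1e3:.3f} ms, match={ok}")
```

On a GPU the same structure would be filled in with the actual cuBLAS calls; sweeping B and M is what would expose the small-B regime mentioned above.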
The proposed strategy could be offered as an option in the current CombineParallelDense pass, and I think the basic implementation logic would closely resemble CombineParallelConv2d. The CombineParallelDense pass could then select the better of the two strategies to get more performance benefit.
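To make the rewrite planning concrete, here is a toy Python sketch over a hypothetical mini-IR. The `Dense` record and `plan_combine` function are invented for illustration and are not TVM APIs. It groups dense ops that read the same input, computes the concatenated weight shape and the split indices that recover each branch, and flags whether the batch_matmul strategy is also applicable (all N_i equal). Choosing between applicable strategies is left open, since that is the performance question this proposal raises.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Dense:
    """Hypothetical toy IR node: output = matmul(inp, W), W of shape [k, n]."""
    name: str
    inp: str
    k: int
    n: int


def plan_combine(ops, min_group=2):
    """Group dense ops that read the same input and plan one concatenated
    matmul per group, plus the split indices that recover each output."""
    groups = defaultdict(list)
    for op in ops:
        groups[op.inp].append(op)

    plans = []
    for inp, members in groups.items():
        if len(members) < min_group:
            continue  # a lone dense op has nothing to combine with
        widths = [op.n for op in members]
        split_indices, acc = [], 0
        for w in widths[:-1]:
            acc += w
            split_indices.append(acc)
        plans.append({
            "input": inp,
            "outputs": [op.name for op in members],
            "weight_shape": (members[0].k, sum(widths)),
            "split_indices": split_indices,
            # batch_matmul is only an alternative when every N_i is equal
            "batch_matmul_applicable": len(set(widths)) == 1,
        })
    return plans
```

For Q/K/V projections of widths 64, 64 and 32 on input `x`, this plans one matmul with a [64, 160] weight split at [64, 128], and marks batch_matmul as not applicable.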