Developing a faster schedule for Longformer's kernel

We recently released a transformer model for long documents that is powered by a custom CUDA kernel implemented in TVM (here's the TVM account tweeting about it).

Would anyone be interested in implementing a faster schedule for the kernel? I think it would be a great showcase for the usability and efficiency of TVM, and it could have a big impact on the NLP community.

In case anyone is interested, here is some background:

  • The kernel is a form of banded matrix multiplication where we only compute certain diagonals of the output matrix (see Figures 2b and 2c in the paper); a minimal sketch of this pattern follows the list below.

  • Our schedule here is 16x slower than it should be.

  • The batched_matmul schedule here is 2x faster than ours for the setting in Figure 2b (I will use it instead of our schedule for this case), but it is much worse than our schedule for the setting in Figure 2c.
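To make the first bullet concrete, here is a minimal sketch of the banded/diagonaled pattern, not the released kernel: it assumes a TVM version where the te schedule API (te.create_schedule / tvm.lower) is available, float32 inputs, a (batch, seq_len, heads, hidden) layout, and a symmetric window of w positions on each side; the names Q, K, S and diagonaled_matmul are illustrative, not from our code.

import tvm
from tvm import te

def diagonaled_matmul(batch, seq_len, heads, hidden, w):
    # S[b, i, h, d] = <Q[b, i, h, :], K[b, i + d - w, h, :]> for d in [0, 2w],
    # i.e. one dot product per kept diagonal, zeroed where the column index
    # i + d - w falls outside the matrix.
    Q = te.placeholder((batch, seq_len, heads, hidden), name="Q", dtype="float32")
    K = te.placeholder((batch, seq_len, heads, hidden), name="K", dtype="float32")
    k = te.reduce_axis((0, hidden), name="k")
    S = te.compute(
        (batch, seq_len, heads, 2 * w + 1),
        lambda b, i, h, d: te.sum(
            tvm.tir.if_then_else(
                tvm.tir.all(i + d - w >= 0, i + d - w < seq_len),
                Q[b, i, h, k] * K[b, i + d - w, h, k],
                tvm.tir.const(0.0, "float32"),
            ),
            axis=k,
        ),
        name="S",
    )
    return Q, K, S

# Naive schedule, just to have a lowerable baseline to compare candidate schedules against.
Q, K, S = diagonaled_matmul(batch=1, seq_len=4096, heads=12, hidden=768, w=256)
s = te.create_schedule(S.op)
print(tvm.lower(s, [Q, K, S], simple_mode=True))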

So the question is whether we can implement a schedule that is faster than both ours and batched_matmul. If anyone is interested in working on this, please let me know.

Thanks


If the workload is specialized, it is easier to find a better schedule. Could you provide the workloads used by a pretrained model?

Sure. Referring to Figure 2c in the paper, you can assume a typical workload is:

  • batch_size < 20 (dimension not shown in the figure)
  • embedding size: 768 (dimension not shown in the figure)
  • sequence length (width and height of the matrix) is a multiple of 512 and ranges from 512 to 512x64 (32,768); shorter sequences are more common than longer ones
  • window size (number of green dots in a row) is usually a power of two and ranges from 32 to 8,192, with 512 being the most common
  • dilation (the gray gaps between green dots) is < 10, with 0 (no dilation) being the most common; see the indexing sketch below for how dilation changes the access pattern

Is this the kind of information you need?
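For what it's worth, here is how I read the window/dilation geometry above, as a tiny pure-Python sketch; the offset convention (dilation = number of skipped columns between consecutive attended positions) is my assumption, not taken from the released kernel, and attended_column is a hypothetical helper.

def attended_column(i, d, window_size, dilation):
    # Which column does row i attend to for diagonal index d in [0, window_size]?
    half = window_size // 2       # green dots on each side of the main diagonal
    step = dilation + 1           # dilation 0 -> adjacent columns, dilation 3 -> every 4th column
    return i + (d - half) * step  # can fall outside [0, seq_len) near the matrix edges

# Row 1000 with window 512: dilation 0 touches the contiguous columns 744..1256,
# while dilation 3 touches columns -24..2024 in steps of 4, a much wider and strided
# footprint, which is why the caching/locality assumptions break.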

If you want one or two specific configurations to work with, they would be the following (also written out as benchmark configs in the sketch after this list):

  • batch size = 12 (but batch_matmul_schedule didn’t require a constant batch size, so maybe this doesn’t need to be constant)
  • embedding size: 768
  • sequence length: 4,096
  • window size: 512
  • dilation: 0 and 3 (I think a lot of the locality assumptions for caching will break once we start working with non-zero dilation. That's why we need to study both cases: 0 because it is the most common, and 3 because it is representative of the cases where locality breaks)
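To make those two settings easy to plug into a benchmark, here is one hypothetical way to write them down as configs (the names are illustrative); the diagonals count assumes the output keeps one score per attended position plus the main diagonal, i.e. window + 1, which matches the t3d3 = 513 constant further down in the thread.

benchmark_configs = [
    {"batch_size": 12, "hidden": 768, "seq_len": 4096, "window": 512, "dilation": 0},
    {"batch_size": 12, "hidden": 768, "seq_len": 4096, "window": 512, "dilation": 3},
]

for cfg in benchmark_configs:
    diagonals = cfg["window"] + 1  # 513 output values per (row, head)
    print(cfg, "->", diagonals, "diagonals per row")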

Could you provide values for these parameters?

b = tvm.var('b')  # batch size -> 12 or n
n = tvm.var('n')  # sequence length -> 512
h = tvm.var('h')  # number of heads -> ?
m = tvm.var('m')  # hidden dimension -> 768
w = tvm.var('w')  # window size -> 512
w_upper = tvm.var('w_upper')  # window size to the right of the word. Should be `0` or `w` -> ?
padding = tvm.var('padding')  # padding -> ?
transpose_t1 = tvm.var('transpose_t1')  # t1 should be transposed -> True / False
t1d3 = tvm.var('t1d3')  # last dimension of t1 -> ?
t3d3 = tvm.var('t3d3')  # last dimension of t3 (the result tensor) -> ?

Sorry for the complicated code. I didn't know how to compile multiple kernels into one .so file, so I ended up cramming three functions (one for the forward pass and two for the backward) into one, with flags to switch between them; one possible way to avoid that is sketched after the constants below. Here are the constants for the forward pass:

b = 1  # batch size
n = 4096  # sequence length
h = 12  # number of heads (this dimension can be merged with the batch size if needed)
m = 768  # hidden dimension
w = 256  # window size on one side
w_upper = 256  # window size to the right of the word; should be `w` for the non-autoregressive case
padding = 0  # padding -> any constant
transpose_t1 = 0  # `0` for one of the backward functions and `1` for the other; doesn't matter for the forward pass
t1d3 = 768  # last dimension of t1: `m` for the forward function and `2w + 1` (number of diagonals) for the backward
t3d3 = 513  # last dimension of t3 (the result tensor): `2w + 1` for the forward pass
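On the flag-switching point above: assuming a TVM version where tvm.lower returns an IRModule, one possible way to keep the forward and the two backward kernels as separate functions but still ship a single .so is to lower each schedule under its own name, merge the modules, and export once. The build_* helpers below are hypothetical placeholders for whatever builds each schedule and its tensor arguments; this is a sketch, not tested against the released code.

import tvm

def export_all(build_forward, build_backward_t1, build_backward_t2,
               target="cuda", path="longformer_kernels.so"):
    # Each hypothetical helper returns (schedule, [tensor args]) for one function.
    s_fwd, args_fwd = build_forward()
    s_b1, args_b1 = build_backward_t1()
    s_b2, args_b2 = build_backward_t2()

    # Lower each schedule into its own PrimFunc, then merge everything into one IRModule.
    mod = tvm.lower(s_fwd, args_fwd, name="forward")
    mod.update(tvm.lower(s_b1, args_b1, name="backward_t1"))
    mod.update(tvm.lower(s_b2, args_b2, name="backward_t2"))

    lib = tvm.build(mod, target=target)  # one runtime module with three entry points
    lib.export_library(path)             # single .so; load later with tvm.runtime.load_module
    return lib

# forward = export_all(...)["forward"]  # no transpose_t1-style flags needed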