What is vectorization and is it just a hint?

Hi, I’m confused about the vectorize schedule. Two questions:

  1. Is it just a hint for TVM (i.e., it is fine for TVM to not follow it), or is it something that will be strictly enforced?

  2. What’s the exact meaning of vectorizing a loop? I thought it was partitioning the loop into multiple vector loads/stores. But when I try to vectorize a loop of length 1000, the script below only works for the LLVM target and fails on CUDA. The error message is TVMError: Cannot convert type int32x1000 to CUDA type, which sounds like it was trying to convert the whole loop into one vector load/store?

import tvm
from tvm import te
import numpy as np

# tgt = tvm.target.Target(target="cuda", host="llvm") # failed: TVMError: Cannot convert type int32x1000 to CUDA type
tgt = tvm.target.Target(target="llvm", host="llvm") # succeeded
dev = tvm.device(tgt.kind.name, 0)

# get te
N = 30000
A = te.compute((N,), lambda i: i, name="A")
s = te.create_schedule(A.op)
oi, ii = s[A].split(A.op.axis[0], factor=1000)
if tgt.kind.name == 'cuda': s[A].bind(oi, te.thread_axis("threadIdx.x"))
s[A].vectorize(ii)

# run
print(tvm.lower(s, [A], simple_mode=True))
foo = tvm.build(s, [A], tgt, name="foo")
dev = tvm.cuda() if tgt.kind.name == 'cuda' else tvm.cpu()
a = tvm.nd.array(np.zeros(N,).astype(A.dtype), dev)
foo(a)

It’s not a hint, but the codegen may fail to vectorize if the required conditions aren’t met. On failure, you can see one of two outcomes: either it has no effect and the loop is still processed sequentially, or it throws an error as you reported.

For CUDA, AFAIK, vectorization is only effective when the dtype is float16. In that case, the CUDA codegen packs two half values together and uses half2 to better utilize the bandwidth. As a result, the vectorize size has to be divisible by 2. You can refer to the discussion here: [CUDA] Enable half2 in CUDA injective schedule - #4 by vinx13.
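For example, with float16 data and an inner vectorized extent of 2, the CUDA codegen should be able to emit the packed half2-style accesses. This is only a rough sketch against the same te schedule API as your script; the sizes are made up for illustration:

import tvm
from tvm import te

# Sketch only: float16 data, 2-wide inner loop vectorized so the CUDA codegen
# can pack the two half values together (sizes are illustrative).
N = 1 << 20
A = te.placeholder((N,), dtype="float16", name="A")
B = te.compute((N,), lambda i: A[i], name="B")
s = te.create_schedule(B.op)
outer, vec = s[B].split(B.op.axis[0], factor=2)   # innermost extent of 2
bx, tx = s[B].split(outer, factor=256)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
s[B].vectorize(vec)
print(tvm.lower(s, [A, B], simple_mode=True))
# tvm.build(s, [A, B], tvm.target.Target("cuda", host="llvm")) should work here.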


Thanks for the explanation! Regarding the dtype: I get a similar error (TVMError: Cannot convert type float16x1024 to CUDA type) when I switch to float16 and adjust the sizes so everything is divisible (basically 256 blocks of 32 threads each, with each thread handling 1024 float16 loads). I think it’s a valid schedule, and based on what you said, shouldn’t the expectation be that each thread generates 512 half2 loads?

I think the core question is: is a vectorize schedule valid if the vectorize size is larger than the maximum vector load size of the backend? On CUDA it seems to always try to generate a single vectorized load of 1024 halves = 1024 × 16 bits, which I don’t think any machine supports, but for LLVM it works fine. So my current hypothesis is that the vectorize schedule does accept overly large vectorization sizes, and ideally they get split into multiple smaller vectorized loads (which is why it works for the LLVM backend), but somehow this part is missing in the CUDA backend.

BTW I printed the TIR before codegen; for both backends it includes something like B_2[ramp(((i.outer.outer*32768) + (i.outer.inner*1024)), 1, 1024)] = (float16x1024*)A_2[ramp(((i.outer.outer*32768) + (i.outer.inner*1024)), 1, 1024)], i.e. a ramp(xxx, 1, 1024) pattern. The LLVM codegen handles this overly long ramp well, but the CUDA codegen does not, which is inconsistent and could be a bug?

import tvm
from tvm import te
import numpy as np

tgt = tvm.target.Target(target="cuda", host="llvm") # failed: TVMError: Cannot convert type float16x1024 to CUDA type
# tgt = tvm.target.Target(target="llvm", host="llvm") # succeeded
dev = tvm.device(tgt.kind.name, 0)

# get te
VEC_WIDTH, NBLOCKS, NTHRDS = 1024, 256, 32
N = NTHRDS * VEC_WIDTH * NBLOCKS
# A = te.compute((N,), lambda i: i, name="A")
A = te.placeholder((N,), dtype='float16', name='A')
B = te.compute((N,), lambda i: A[i], name='B')
s = te.create_schedule(B.op)
oi, ii = s[B].split(B.op.axis[0], factor=VEC_WIDTH)
bx, tx = s[B].split(oi, factor=NTHRDS)
if tgt.kind.name == 'cuda': 
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))
s[B].vectorize(ii)

# run
print(tvm.lower(s, [A, B], simple_mode=True))
foo = tvm.build(s, [A, B], tgt, name="foo")
dev = tvm.cuda() if tgt.kind.name == 'cuda' else tvm.cpu()
a = tvm.nd.array(np.zeros(N,).astype(A.dtype), dev)
b = tvm.nd.array(np.zeros(N,).astype(B.dtype), dev)
foo(a, b)
print('done')
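For what it’s worth, on the llvm target you can also dump the generated LLVM IR to see how the wide ramp ends up; as far as I understand, LLVM itself then legalizes the 1024-wide vectors into whatever widths the CPU actually supports. A quick sketch, reusing s, A, B from the script above with the llvm target:

# Sketch (not in the original script): build for llvm only and inspect the IR.
llvm_tgt = tvm.target.Target(target="llvm", host="llvm")
foo_llvm = tvm.build(s, [A, B], llvm_tgt, name="foo_llvm")
print(foo_llvm.get_source())  # look for the wide <... x half> vector loads/stores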

I think your hypothesis is not 100% correct about the CUDA codegen implementation, although it makes sense. As you’ve found, the CUDA codegen doesn’t split large vectorized loads into smaller ones that it could handle. This is just a limitation of the current CUDA codegen implementation, and you’re welcome to send a PR to improve it.
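In the meantime, a schedule-level workaround is to split the per-thread chunk once more so that only a small, CUDA-friendly width gets vectorized. A sketch based on your script (not a guaranteed recipe):

import tvm
from tvm import te

# Sketch of a possible workaround: keep a serial loop per thread and vectorize
# only a 2-wide (half2-sized) inner piece instead of the full 1024-wide chunk.
VEC_WIDTH, NBLOCKS, NTHRDS = 1024, 256, 32
N = NTHRDS * VEC_WIDTH * NBLOCKS
A = te.placeholder((N,), dtype="float16", name="A")
B = te.compute((N,), lambda i: A[i], name="B")
s = te.create_schedule(B.op)
oi, ii = s[B].split(B.op.axis[0], factor=VEC_WIDTH)
bx, tx = s[B].split(oi, factor=NTHRDS)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
ii_outer, ii_vec = s[B].split(ii, factor=2)  # 512 serial iterations x 2-wide vector
s[B].vectorize(ii_vec)
print(tvm.lower(s, [A, B], simple_mode=True))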
