Hello everyone!
I’m exploring how to run LLMs on Hexagon, starting with a GEMV operator. I’ve tested a script on an 8G2 (Snapdragon 8 Gen 2) device, and the results are interesting: it performs well with N = K = 512, but slows down significantly with N = K = 1024.
# N = K = 512
Time (ms): 0.0078 Total Bytes (MB): 0.50 Memory (GB/s): 64.35
# N = K = 1024
Time (ms): 0.1753 Total Bytes (MB): 2.00 Memory (GB/s): 11.43
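Here is a quick back-of-the-envelope check on these numbers. It assumes an fp16 GEMV y[N] = x[K] @ W[K, N] in which x and W are read once and y is written once (ideal caching); that traffic model is my assumption, not what the benchmark script computes. The weight matrix dominates the bytes, and going from 512 to 1024 roughly quadruples the bytes while the measured time grows about 22x.

```python
# Back-of-the-envelope traffic for fp16 GEMV: y[N] = x[K] @ W[K, N].
# Assumption: each of x, W, y is moved exactly once (ideal caching).
FP16_BYTES = 2

def gemv_bytes(n: int, k: int) -> int:
    """Minimum bytes moved: read x (K), read W (K*N), write y (N)."""
    return FP16_BYTES * (k + k * n + n)

# Measured times from the runs above, in milliseconds.
measured_ms = {512: 0.0078, 1024: 0.1753}

for n, t_ms in measured_ms.items():
    weight_mib = FP16_BYTES * n * n / 2**20
    print(f"N=K={n}: W = {weight_mib:.2f} MiB, total = {gemv_bytes(n, n)} bytes, time = {t_ms} ms")

byte_ratio = gemv_bytes(1024, 1024) / gemv_bytes(512, 512)
time_ratio = measured_ms[1024] / measured_ms[512]
print(f"bytes grow {byte_ratio:.2f}x, time grows {time_ratio:.1f}x")
```

If raw byte count were the only factor, time would scale roughly with bytes; the gap between ~4x and ~22x suggests something else changes at 1024, though this arithmetic alone cannot say what.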
Here are the scheduled and lowered functions for N = K = 1024:
from tvm.script import tir as T

@T.prim_func
def func(A: T.Buffer((1, 1, 1024), "float16"), B: T.Buffer((1024, 1024), "float16"), C: T.Buffer((1, 1, 1024), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i0, i1 in T.grid(1, 1):
        for i2_0 in T.parallel(4):
            for i2_1 in range(2):
                for i2_2_init in T.unroll(2):
                    for i2_3_init in T.vectorized(64):
                        with T.block("NT_matmul_init"):
                            v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                            v_i2 = T.axis.spatial(1024, i2_0 * 256 + i2_1 * 128 + i2_2_init * 64 + i2_3_init)
                            T.reads()
                            T.writes(C[v_i0, v_i1, v_i2])
                            C[v_i0, v_i1, v_i2] = T.float16(0)
                for k_0 in range(256):
                    for i2_2 in T.unroll(2):
                        for k_1 in range(4):
                            for i2_3 in T.vectorized(64):
                                with T.block("NT_matmul_update"):
                                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                                    v_i2 = T.axis.spatial(1024, i2_0 * 256 + i2_1 * 128 + i2_2 * 64 + i2_3)
                                    v_k = T.axis.reduce(1024, k_0 * 4 + k_1)
                                    T.reads(C[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_k], B[v_k, v_i2])
                                    T.writes(C[v_i0, v_i1, v_i2])
                                    C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_k, v_i2]
from tvm.script import ir as I
from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1, 1, 1024), "float16"), B: T.Buffer((1024, 1024), "float16"), C: T.Buffer((1, 1, 1024), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for i2_0 in T.parallel(4):
            for i2_1 in range(2):
                cse_var_1: T.int32 = i2_0 * 256 + i2_1 * 128
                C_1 = T.Buffer((1024,), "float16", data=C.data)
                C_1[cse_var_1:cse_var_1 + 64] = T.Broadcast(T.float16(0), 64)
                C_1[cse_var_1 + 64:cse_var_1 + 64 + 64] = T.Broadcast(T.float16(0), 64)
                for k_0 in range(256):
                    A_1 = T.Buffer((1024,), "float16", data=A.data)
                    B_1 = T.Buffer((1048576,), "float16", data=B.data)
                    for k_1 in range(4):
                        C_1[cse_var_1:cse_var_1 + 64] = C_1[cse_var_1:cse_var_1 + 64] + T.Broadcast(A_1[k_0 * 4 + k_1], 64) * B_1[k_0 * 4096 + k_1 * 1024 + i2_0 * 256 + i2_1 * 128:k_0 * 4096 + k_1 * 1024 + i2_0 * 256 + i2_1 * 128 + 64]
                    for k_1 in range(4):
                        cse_var_2: T.int32 = cse_var_1 + 64
                        C_1[cse_var_2:cse_var_2 + 64] = C_1[cse_var_2:cse_var_2 + 64] + T.Broadcast(A_1[k_0 * 4 + k_1], 64) * B_1[k_0 * 4096 + k_1 * 1024 + i2_0 * 256 + i2_1 * 128 + 64:k_0 * 4096 + k_1 * 1024 + i2_0 * 256 + i2_1 * 128 + 64 + 64]
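To check the "contiguous access" assumption, the B index expression from the lowered code can be replayed in plain Python. This is a sketch that hard-codes the N = K = 1024 schedule above (tile offsets 256/128/64); it only lists byte offsets of the 64-lane loads (64 fp16 = 128 bytes, one HVX vector in 128-byte mode) in program order and does not model Hexagon's caches.

```python
# Replays the B_1[...] index expression from the lowered TIR (N = K = 1024)
# and lists the byte offset of every 64-lane fp16 load, in program order.
FP16_BYTES = 2
VEC_LANES = 64  # i2_3 extent; 64 fp16 = 128 bytes

def b_load_offsets(i2_0: int, i2_1: int) -> list:
    """Byte offsets of the B loads issued by one (i2_0, i2_1) tile."""
    offsets = []
    for k_0 in range(256):
        for half in (0, 1):  # the two unrolled i2_2 copies; the second adds + 64
            for k_1 in range(4):
                elem = k_0 * 4096 + k_1 * 1024 + i2_0 * 256 + i2_1 * 128 + half * VEC_LANES
                offsets.append(elem * FP16_BYTES)
    return offsets

offs = b_load_offsets(0, 0)
strides = [b - a for a, b in zip(offs, offs[1:])]
print("loads per tile:", len(offs))
print("bytes per load:", VEC_LANES * FP16_BYTES)
print("distinct strides between consecutive loads:", sorted(set(strides)))
```

Read this way, each tile loads one 128-byte chunk from a row of B and then jumps about 2 KiB to the next row, so accesses are contiguous within each vector but strided across rows. Whether that pattern is what hurts at 1024 is still open.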
Since this is my first time working with Hexagon and I know little about the target, I have a few noob questions:
- Could the performance difference be due to cache misses? I thought I had ensured contiguous and aligned data access.
- Are there any examples of optimizing GEMM/GEMV on HVX? I found some i8 schedules, but LLM tasks seem to favor fp16.
- I tried using VTCM by caching buffer A in on-chip memory, but it made things slower. Is this expected?
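On the VTCM question, one way to see which buffer is actually reused is to count loads per element in the lowered loop nest. Again, this is a sketch hard-coding the N = K = 1024 schedule above; it counts A as one scalar load per `T.Broadcast` and B as one 64-lane load per chunk.

```python
from collections import Counter

def count_loads() -> tuple:
    """Count loads of each A element and each 64-lane B chunk (N = K = 1024)."""
    a_loads: Counter = Counter()
    b_chunk_loads: Counter = Counter()
    for i2_0 in range(4):            # T.parallel(4)
        for i2_1 in range(2):
            for k_0 in range(256):
                for half in (0, 1):  # the two unrolled 64-lane halves
                    for k_1 in range(4):
                        k = k_0 * 4 + k_1
                        a_loads[k] += 1  # scalar A_1[k] feeding T.Broadcast
                        col = i2_0 * 256 + i2_1 * 128 + half * 64
                        b_chunk_loads[k * 1024 + col] += 1  # start of the 64-lane B slice
    return a_loads, b_chunk_loads

a_loads, b_chunk_loads = count_loads()
print("A: distinct elements", len(a_loads), "loads each", set(a_loads.values()))
print("B: distinct 128-byte chunks", len(b_chunk_loads), "loads each", set(b_chunk_loads.values()))
```

By this count, each element of A (2 KiB total) is loaded 16 times, while each 128-byte chunk of B (2 MiB total) is loaded exactly once, so B accounts for most of the bytes moved whether or not A lives in VTCM.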
Any insights would be greatly appreciated. Thank you!