Need help with Hexagon HVX optimization

Hello everyone!

I’m exploring how to run LLMs on Hexagon, starting with a GEMV operator. I tested a script on a Snapdragon 8 Gen 2 (8G2) device, and the results are interesting. It performs well with N = K = 512, but it slows down significantly with N = K = 1024, where the effective bandwidth drops from about 64 GB/s to about 11 GB/s.

```
# N = K = 512
Time (ms): 0.0078       Total Bytes (MB): 0.50  Memory (GB/s): 64.35
# N = K = 1024
Time (ms): 0.1753       Total Bytes (MB): 2.00  Memory (GB/s): 11.43
```
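The bandwidth figures above can be reproduced with a small calculation. This is a sketch, assuming the script counts each fp16 element of A, B and C once and reports "GB/s" as MiB divided by seconds and by 1000. That convention is a guess because the script isn't shown, but it reproduces the reported numbers. `gemv_bytes` and `bandwidth_gbs` are hypothetical helper names.

```python
# Effective bandwidth of an fp16 GEMV: C[1, N] = A[1, K] @ B[K, N].
# Assumption: A, B and C are each counted once, 2 bytes per fp16 element,
# and "GB/s" is MiB / seconds / 1000 (a guess at the script's convention).
FP16_BYTES = 2

def gemv_bytes(n: int, k: int) -> int:
    """Total bytes of A (K elements), B (K * N elements) and C (N elements)."""
    return FP16_BYTES * (k + k * n + n)

def bandwidth_gbs(n: int, k: int, time_ms: float) -> float:
    """Reported-style bandwidth: MiB moved per second, divided by 1000."""
    mib = gemv_bytes(n, k) / 2**20
    return mib / (time_ms * 1e-3) / 1000

for size, t_ms in [(512, 0.0078), (1024, 0.1753)]:
    print(f"N = K = {size}: {gemv_bytes(size, size) / 2**20:.2f} MB, "
          f"{bandwidth_gbs(size, size, t_ms):.2f} GB/s")
```

The byte count scales roughly with N × K, so the 1024 case moves about 4× the data of the 512 case. Its time, however, grows by more than 20×.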

The scheduled function (shown for N = K = 1024) and the lowered module are:

```python
from tvm.script import tir as T

@T.prim_func
def func(A: T.Buffer((1, 1, 1024), "float16"), B: T.Buffer((1024, 1024), "float16"), C: T.Buffer((1, 1, 1024), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i0, i1 in T.grid(1, 1):
        for i2_0 in T.parallel(4):
            for i2_1 in range(2):
                for i2_2_init in T.unroll(2):
                    for i2_3_init in T.vectorized(64):
                        with T.block("NT_matmul_init"):
                            v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                            v_i2 = T.axis.spatial(1024, i2_0 * 256 + i2_1 * 128 + i2_2_init * 64 + i2_3_init)
                            T.reads()
                            T.writes(C[v_i0, v_i1, v_i2])
                            C[v_i0, v_i1, v_i2] = T.float16(0)
                for k_0 in range(256):
                    for i2_2 in T.unroll(2):
                        for k_1 in range(4):
                            for i2_3 in T.vectorized(64):
                                with T.block("NT_matmul_update"):
                                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                                    v_i2 = T.axis.spatial(1024, i2_0 * 256 + i2_1 * 128 + i2_2 * 64 + i2_3)
                                    v_k = T.axis.reduce(1024, k_0 * 4 + k_1)
                                    T.reads(C[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_k], B[v_k, v_i2])
                                    T.writes(C[v_i0, v_i1, v_i2])
                                    C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_k, v_i2]
```
```python
from tvm.script import ir as I
from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1, 1, 1024), "float16"), B: T.Buffer((1024, 1024), "float16"), C: T.Buffer((1, 1, 1024), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for i2_0 in T.parallel(4):
            for i2_1 in range(2):
                cse_var_1: T.int32 = i2_0 * 256 + i2_1 * 128
                C_1 = T.Buffer((1024,), "float16", data=C.data)
                C_1[cse_var_1:cse_var_1 + 64] = T.Broadcast(T.float16(0), 64)
                C_1[cse_var_1 + 64:cse_var_1 + 64 + 64] = T.Broadcast(T.float16(0), 64)
                for k_0 in range(256):
                    A_1 = T.Buffer((1024,), "float16", data=A.data)
                    B_1 = T.Buffer((1048576,), "float16", data=B.data)
                    for k_1 in range(4):
                        C_1[cse_var_1:cse_var_1 + 64] = C_1[cse_var_1:cse_var_1 + 64] + T.Broadcast(A_1[k_0 * 4 + k_1], 64) * B_1[k_0 * 4096 + k_1 * 1024 + i2_0 * 256 + i2_1 * 128:k_0 * 4096 + k_1 * 1024 + i2_0 * 256 + i2_1 * 128 + 64]
                    for k_1 in range(4):
                        cse_var_2: T.int32 = cse_var_1 + 64
                        C_1[cse_var_2:cse_var_2 + 64] = C_1[cse_var_2:cse_var_2 + 64] + T.Broadcast(A_1[k_0 * 4 + k_1], 64) * B_1[k_0 * 4096 + k_1 * 1024 + i2_0 * 256 + i2_1 * 128 + 64:k_0 * 4096 + k_1 * 1024 + i2_0 * 256 + i2_1 * 128 + 64 + 64]
```
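The lowered loop nest is easier to follow with a pure-Python emulation that takes the tile sizes as parameters. The names `n0`, `n1`, `n2`, `vec`, `k0` and `k1` are my own and are not TVM's. For the posted schedule they are 4, 2, 2, 64, 256 and 4. The check at the end uses a much smaller configuration with the same loop structure and compares the result against a plain GEMV. This is a sketch for reading the schedule, not a performance model.

```python
import random

def gemv_reference(a, b, n, k):
    """Plain GEMV: c[j] = sum over kk of a[kk] * b[kk, j], with B flattened row-major."""
    return [sum(a[kk] * b[kk * n + j] for kk in range(k)) for j in range(n)]

def gemv_tiled(a, b, n0, n1, n2, vec, k0, k1):
    """Same loop order and index math as the lowered TIR, with tile sizes as parameters.

    Posted schedule: n0=4 (T.parallel), n1=2, n2=2 (T.unroll), vec=64
    (T.vectorized, i.e. 64 fp16 lanes = one 128-byte HVX vector), k0=256, k1=4.
    """
    n, k = n0 * n1 * n2 * vec, k0 * k1
    c = [0] * n                                   # zero-init (the Broadcast(0) stores)
    for i2_0 in range(n0):                        # parallel across threads in TIR
        for i2_1 in range(n1):
            base = i2_0 * (n1 * n2 * vec) + i2_1 * (n2 * vec)  # cse_var_1
            for k_0 in range(k0):
                for i2_2 in range(n2):            # unrolled: one accumulator vector each
                    col = base + i2_2 * vec
                    for k_1 in range(k1):
                        kk = k_0 * k1 + k_1
                        row = kk * n              # k_0 * 4096 + k_1 * 1024 when n = 1024, k1 = 4
                        for lane in range(vec):   # a single vector multiply-add in TIR
                            c[col + lane] += a[kk] * b[row + col + lane]
    return c

# Small stand-in for (4, 2, 2, 64, 256, 4) with the same loop structure.
random.seed(0)
n0, n1, n2, vec, k0, k1 = 2, 2, 2, 4, 3, 4
n, k = n0 * n1 * n2 * vec, k0 * k1
a = [random.randint(-3, 3) for _ in range(k)]
b = [random.randint(-3, 3) for _ in range(k * n)]
assert gemv_tiled(a, b, n0, n1, n2, vec, k0, k1) == gemv_reference(a, b, n, k)
```

Integer inputs keep the comparison exact even though the tiled version sums in a different order than the reference.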

Since this is my first time working with Hexagon and I know little about the target, I have a few noob questions:

  1. Could the performance difference be due to cache misses? I thought I had ensured contiguous and aligned data access.
  2. Are there any examples of optimizing GEMM/GEMV on HVX? I found some int8 schedules, but LLM workloads seem to favor fp16.
  3. I tried using VTCM by caching buffer A in on-chip memory, but it made things slower. Is this expected?
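For question 1, one way to sanity-check the access pattern is to enumerate the B offsets that the lowered code loads. This sketch assumes 2-byte fp16 elements, 128-byte HVX vectors, and a 128-byte-aligned base address for B, none of which is verified here. Under those assumptions, every 64-lane load is aligned and each slice of B is read exactly once. However, consecutive reduction steps jump a full row of B (N × 2 bytes, i.e. 2 KiB for N = 1024). This doesn't show that cache misses are the cause. It only makes the "contiguous" part of the question concrete.

```python
# Element offsets of the B slices in the lowered code for N = K = 1024.
# Assumptions: fp16 (2 bytes), 128-byte HVX vectors, B's base is 128-byte aligned.
FP16_BYTES = 2
HVX_BYTES = 128
N = K = 1024

def b_slice_offset(k_0, k_1, i2_0, i2_1, i2_2):
    """Start element of B_1[...] in the lowered code; i2_2 = 1 is the '+ 64' half."""
    return k_0 * 4096 + k_1 * 1024 + i2_0 * 256 + i2_1 * 128 + i2_2 * 64

byte_offsets = [
    b_slice_offset(k_0, k_1, i2_0, i2_1, i2_2) * FP16_BYTES
    for k_0 in range(256) for k_1 in range(4)
    for i2_0 in range(4) for i2_1 in range(2) for i2_2 in range(2)
]

# Every 64-lane fp16 load starts on a 128-byte boundary relative to B's base.
all_aligned = all(off % HVX_BYTES == 0 for off in byte_offsets)
# No slice is loaded twice, and together the slices cover all of B.
covers_b_once = (len(set(byte_offsets)) == len(byte_offsets)
                 and len(byte_offsets) * HVX_BYTES == N * K * FP16_BYTES)
# Consecutive reduction steps (k_1 -> k_1 + 1) jump one full row of B.
k_stride_bytes = (b_slice_offset(0, 1, 0, 0, 0) - b_slice_offset(0, 0, 0, 0, 0)) * FP16_BYTES

print(all_aligned, covers_b_once, k_stride_bytes)
```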

Any insights would be greatly appreciated. Thank you!

cc @kparzysz @sanirudh

Hi @Hzfengsy,

I’m not sure about cache misses, and I can’t recall any examples off the top of my head, but I’ll certainly post one if I find it.

As for using VTCM, I don’t know whether it would help. However, since you’re using the v73 architecture, whose VTCM capacity is 8 MB, we can allocate all three buffers (A, B and C) in VTCM and perform the computation there. At least for the 1024 case, I’d guess that should improve performance.
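Here is a quick footprint check of the 8 MB figure, assuming fp16 buffers and taking 8 MB as 8 × 2^20 bytes. For N = K = 1024 the three buffers total about 2 MiB, so they fit comfortably. The same numbers show that B is over 99% of the data. That suggests caching only A (2 KiB), as in question 3, has little to gain. `buffer_bytes` is a hypothetical helper.

```python
# Does the whole GEMV working set fit in VTCM? Assumes fp16 buffers and the
# 8 MB VTCM capacity quoted for v73 (taken as 8 * 2**20 bytes here).
FP16_BYTES = 2
VTCM_BYTES = 8 * 2**20

def buffer_bytes(n, k):
    """Per-buffer footprint of A (1x1xK), B (KxN) and C (1x1xN) in fp16."""
    return {"A": k * FP16_BYTES, "B": k * n * FP16_BYTES, "C": n * FP16_BYTES}

for size in (512, 1024):
    sizes = buffer_bytes(size, size)
    total = sum(sizes.values())
    share_b = sizes["B"] / total
    print(f"N = K = {size}: total {total / 2**20:.2f} MiB, "
          f"B is {share_b:.1%} of it, fits in VTCM: {total <= VTCM_BYTES}")
```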

I quickly tried this with your script, modifying the NDArray allocation to something like this:

```python
args = [allocate_hexagon_array(device, data=arg, mem_scope="global.vtcm") for arg in np_args]
```

I get the following results:

```
# N = K = 512
Time (ms): 0.0280       Total Bytes (MB): 0.50  Memory (GB/s): 17.93
# N = K = 1024
Time (ms): 0.1060       Total Bytes (MB): 2.00  Memory (GB/s): 18.90
```

I’m not sure why the 512 version performs poorly with VTCM, and I may have to dig into that. For the 1024 case, though, it does seem to help. Without VTCM allocation, I see results similar to the ones you posted.
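Here are the two sets of timings side by side. The default ("global") numbers come from the original post, since the reply reports similar results without VTCM, and the "global.vtcm" numbers come from the reply. The dictionary layout and `vtcm_speedup` are illustrative only.

```python
# Timings (ms) reported in this thread. "global" is the default allocation
# (original post); "global.vtcm" is the reply's VTCM allocation.
timings_ms = {
    512: {"global": 0.0078, "global.vtcm": 0.0280},
    1024: {"global": 0.1753, "global.vtcm": 0.1060},
}

def vtcm_speedup(size: int) -> float:
    """Ratio > 1 means the VTCM-backed run was faster."""
    t = timings_ms[size]
    return t["global"] / t["global.vtcm"]

for size in timings_ms:
    print(f"N = K = {size}: VTCM speedup {vtcm_speedup(size):.2f}x")
```

VTCM gives roughly a 1.65× speedup at 1024 but is about 3.6× slower at 512, which is the open question raised above.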
