[TE][vectorize] Is there a plan to support vectorization for non-divisible splits?

TVM has rather poor performance on non-divisible dimensions, for example a 127x127x127 GEMM op:

import tvm
from tvm import te
import numpy as np


M = N = K = 127
dtype = "float32"
target = "llvm -mcpu=skylake-avx512"
device = tvm.device(target, 0)

A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
k = te.reduce_axis((0, K), name="k")

C = te.compute((M, N), lambda i, j: te.sum( A[i, k] * B[k, j], axis=k), name="C")

s = te.create_schedule(C.op)

m, n = s[C].op.axis
k_axis = s[C].op.reduce_axis[0]

s[C].reorder(m, k_axis, n)
no, ni = s[C].split(n, factor=64)
#s[C].vectorize(ni)  # has no effect here: 127 is not divisible by 64

print(tvm.lower(s, [A, B, C], simple_mode=True))

with tvm.target.Target(target):  # tvm.target.create is deprecated in favor of Target
    func = tvm.build(s, [A, B, C])
    print(func.get_source("asm"))

evaluator = func.time_evaluator(func.entry_name, device, number=100)

a = tvm.nd.array(np.random.rand(M, K).astype(dtype), device)
b = tvm.nd.array(np.random.rand(K, N).astype(dtype), device)
c = tvm.nd.array(np.random.rand(M, N).astype(dtype), device)

t = evaluator(a, b, c).mean  # measure once and reuse the mean for both numbers
print("time: %f ms, GFLOPS: %f" % (t * 1000, 2 * M * N * K / t / 1e9))

That’s mainly because the “vectorize” primitive doesn’t take effect on a non-divisible split, which in this case is “ni”.
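To see what the split actually produces, here is a minimal pure-Python model (no TVM; `split_with_guard` is a hypothetical helper, not a TVM API). A split with factor 64 on an extent of 127 yields ceil(127/64) = 2 outer iterations of 64 inner lanes, and every lane carries the `likely(j < 127)` check seen in the TIR, even though only one of the 128 generated iterations (j = 127) actually fails it.

```python
def split_with_guard(extent, factor):
    """Model a TE split: outer extent is ceil(extent / factor), and each
    (outer, inner) pair maps to j = outer * factor + inner, valid only if j < extent."""
    outer_extent = (extent + factor - 1) // factor  # ceil division
    valid, masked = [], []
    for outer in range(outer_extent):
        for inner in range(factor):
            j = outer * factor + inner
            # This mirrors the likely(j < extent) check TVM places inside the
            # inner loop; a condition there keeps the loop in scalar form.
            (valid if j < extent else masked).append(j)
    return outer_extent, valid, masked


outer_extent, valid, masked = split_with_guard(127, 64)
print(outer_extent, len(valid), masked)  # 2 127 [127]
```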

Do we have any plan to support such vectorization on non-divisible split dimensions?

CC @comaniac @masahi, if you’re interested.

How about padding the input to make it divisible?
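A minimal sketch of the padding idea in plain Python (no TVM; `round_up`, `pad_to_multiple`, `matmul`, and `padded_matmul` are hypothetical helpers): zero-pad both operands up to the next multiple of the split factor, run the GEMM on the divisible shape, then slice the result back. The padded zero rows and columns add nothing to the sums, so the valid region is unchanged.

```python
def round_up(x, multiple):
    return (x + multiple - 1) // multiple * multiple


def pad_to_multiple(mat, multiple):
    """Zero-pad a list-of-lists matrix so both dimensions are multiples of `multiple`."""
    rows, cols = len(mat), len(mat[0])
    prow, pcol = round_up(rows, multiple), round_up(cols, multiple)
    padded = [row + [0.0] * (pcol - cols) for row in mat]
    padded += [[0.0] * pcol for _ in range(prow - rows)]
    return padded


def matmul(a, b):
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)] for i in range(n)]


def padded_matmul(a, b, multiple):
    m, n = len(a), len(b[0])
    c = matmul(pad_to_multiple(a, multiple), pad_to_multiple(b, multiple))
    return [row[:n] for row in c[:m]]  # slice back to the original M x N


a = [[float(i * 5 + j) for j in range(5)] for i in range(3)]  # 3 x 5
b = [[float(i - j) for j in range(7)] for i in range(5)]      # 5 x 7
assert padded_matmul(a, b, 4) == matmul(a, b)
print(round_up(127, 64))  # 128
```

The cost is an extra padded copy of each input, which is what the original poster later wants to avoid.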

Hi, in this case, can you try eliminating the condition to make it vectorizable?

with tvm.ir.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):

In my Mac environment it gets about 4.x times the GFLOPS of the original script, but still only 1/4 of the [128, 128, 128] version. So padding to regular shapes could be a better way.
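The effect that `partition_const_loop` is meant to have can be modeled in plain Python. This is a simplified illustration, not TVM's actual LoopPartition pass, and `partition_loop` is a hypothetical helper: the guarded range is split into full, condition-free chunks of `factor` lanes plus one shorter tail chunk, and each chunk can then be emitted as a straight-line vector update.

```python
def partition_loop(extent, factor):
    """Split [0, extent) into (start, length) chunks. Full chunks of `factor`
    lanes need no bounds check; only the final tail chunk (if any) is shorter."""
    full, tail = divmod(extent, factor)
    chunks = [(i * factor, factor) for i in range(full)]
    if tail:
        chunks.append((full * factor, tail))
    return chunks


print(partition_loop(127, 64))  # [(0, 64), (64, 63)]
print(partition_loop(128, 64))  # [(0, 64), (64, 64)]
```

For the 127 case this gives one 64-wide chunk and one 63-wide chunk, which matches the two `ramp` stores of widths 64 and 63 that appear in the simple_mode=False IR in this thread.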


I tried this on my Cascade Lake server, but found it even slower than the version before condition elimination…

This is the TIR generated after eliminating the condition:

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [16129], []),
             B: Buffer(B_2: Pointer(float32), float32, [16129], []),
             C: Buffer(C_2: Pointer(float32), float32, [16129], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i: int32, 0, 127) {
    for (j.outer.init: int32, 0, 2) {
      for (j.inner.init.s: int32, 0, 64) {
        if @tir.likely((((j.outer.init*64) + j.inner.init.s) < 127), dtype=bool) {
          C[(((i*127) + (j.outer.init*64)) + j.inner.init.s)] = 0f32
        }
      }
    }
    for (k: int32, 0, 127) {
      for (j.outer: int32, 0, 2) {
        for (j.inner.s: int32, 0, 64) {
          if @tir.likely((((j.outer*64) + j.inner.s) < 127), dtype=bool) {
            let cse_var_3: int32 = (j.outer*64)
            let cse_var_2: int32 = (i*127)
            let cse_var_1: int32 = ((cse_var_2 + cse_var_3) + j.inner.s)
            C[cse_var_1] = (C[cse_var_1] + (A[(cse_var_2 + k)]*B[(((k*127) + cse_var_3) + j.inner.s)]))
          }
        }
      }
    }
  }
}

It seems no ramp primitive is generated, so the vectorization doesn’t take effect.

Was the TIR output produced like this?

with tvm.ir.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
    print(tvm.lower(s, [A, B, C], simple_mode=False))

@Lyken17 @wrongtest Thanks for your info.

I don’t think the performance is satisfactory if [127, 127, 127] can’t be on par with [128, 128, 128], because that is totally feasible from an HPC programmer’s perspective.
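A back-of-the-envelope check on why parity is plausible, counting only FLOPs (2 * M * N * K) and ignoring memory traffic: padding every dimension from 127 to 128 adds (128/127)^3, or about 2.4% extra arithmetic. `padding_overhead` is just an illustrative helper.

```python
def padding_overhead(extent, padded):
    """Ratio of GEMM FLOPs (2 * M * N * K) when every dimension is padded
    from `extent` to `padded`."""
    return (2 * padded ** 3) / (2 * extent ** 3)


print(f"{padding_overhead(127, 128):.4f}")  # 1.0238
```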

Padding is clearly one way to address the problem, but rather than padding the input explicitly, I prefer to let TVM do it implicitly, “on the fly”, which is much more user-friendly and more efficient in both performance and memory.
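One way to picture “implicit” padding is a masked tail tile. The tail is computed at full vector width, out-of-range loads read as zero, and only the valid lanes are written back, so no padded copy of the inputs is ever materialized. Below is a plain-Python model of that idea. `masked_tile_update` is hypothetical and is not an existing TVM primitive. AVX-512, the target used here, provides per-lane mask registers for masked loads and stores.

```python
def masked_tile_update(c_row, a_ik, b_row, start, width):
    """Compute C[start:start+width] += a_ik * B[start:start+width] as one
    fixed-width "vector", reading out-of-range lanes as 0 and storing only
    lanes that fall inside the real row (assumes len(b_row) == len(c_row))."""
    n = len(c_row)
    # Masked loads: lanes past the end of the row read as zero (padding on the fly).
    b_vec = [b_row[start + l] if start + l < n else 0.0 for l in range(width)]
    c_vec = [c_row[start + l] if start + l < n else 0.0 for l in range(width)]
    out = [c + a_ik * b for c, b in zip(c_vec, b_vec)]
    # Masked store: only valid lanes are written back.
    for l in range(width):
        if start + l < n:
            c_row[start + l] = out[l]
    return c_row


row = masked_tile_update([1.0] * 127, 3.0, [2.0] * 127, 64, 64)
print(len(row), row[63], row[64], row[126])  # 127 1.0 7.0 7.0
```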

And my goal is not to improve (127, 127, 127) specifically. This is just a case I use to demonstrate that non-divisible dimensions haven’t been handled in TVM, which could hurt the performance of TVM-generated kernels in many cases.

In short, I’m just wondering whether there is a plan or ongoing work to address the non-divisible split problem, probably at the TVM primitive level?


Yes. Code attached below:

import tvm
from tvm import te
import numpy as np


M = N = K = 127
dtype = "float32"
target = "llvm -mcpu=skylake-avx512"
device = tvm.device(target, 0)

A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
k = te.reduce_axis((0, K), name="k")

C = te.compute((M, N), lambda i, j: te.sum( A[i, k] * B[k, j], axis=k), name="C")

with tvm.ir.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
    s = te.create_schedule(C.op)

    m, n = s[C].op.axis
    print("m:", m)
    k_axis = s[C].op.reduce_axis[0]
    print("k_axis:", k_axis)


    s[C].reorder(m, k_axis, n)
    no, ni = s[C].split(n, factor=64)
    s[C].vectorize(ni)
    print(tvm.lower(s, [A, B, C], simple_mode=True))


with tvm.ir.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
    with tvm.target.Target(target):  # tvm.target.create is deprecated in favor of Target
        func = tvm.build(s, [A, B, C])
    #print(func.get_source("asm"))

evaluator = func.time_evaluator(func.entry_name, device, number=100)

a = tvm.nd.array(np.random.rand(M, K).astype(dtype), device)
b = tvm.nd.array(np.random.rand(K, N).astype(dtype), device)
c = tvm.nd.array(np.random.rand(M, N).astype(dtype), device)

t = evaluator(a, b, c).mean  # measure once and reuse the mean for both numbers
print("time: %f ms, GFLOPS: %f" % (t * 1000, 2 * M * N * K / t / 1e9))

I think simple_mode=True disables some of the lowering passes (LoopPartition among them); maybe try unsetting it?

As far as I know, our team has tried several approaches for non-perfect tiles:

  • Leave it as it is, but partition the loops so that each partition is condition-free and easy to vectorize (or to apply other optimizations to). Since https://github.com/apache/tvm/pull/9121, hints can be annotated on loop vars to mark that there may be conditions on the var that affect vectorization.

  • In the TensorIR fashion, you can write your own schedule to align the loop extents to divisible values, or do anything else, as long as the original semantics are preserved.

  • Relax the conditions the vectorizer checks. Certain IR patterns (including conditions) are mapped back to scalar form even if the vectorize flag is explicitly annotated on the loop. We may relax some of these forms as long as the target can handle them efficiently.
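For the option of aligning loop extents to divisible values, note that 127 is prime. The only factors that split it evenly are 1 and 127, so no vector-friendly divisible split exists without changing the extent. A quick check in plain Python (`divisible_factors` is an illustrative helper):

```python
def divisible_factors(extent):
    """All split factors that divide `extent` evenly (no tail, no guard)."""
    return [f for f in range(1, extent + 1) if extent % f == 0]


print(divisible_factors(127))  # [1, 127]
print(divisible_factors(128))  # [1, 2, 4, 8, 16, 32, 64, 128]
```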

BTW, the IR I get with simple_mode=False is:

m: iter_var(i, range(min=0, ext=127))
k_axis: iter_var(k, range(min=0, ext=127))
@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [16129], []),
             B: Buffer(B_2: Pointer(float32), float32, [16129], []),
             A: Buffer(A_2: Pointer(float32), float32, [16129], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i: int32, 0, 127) {
    let cse_var_1: int32 = (i*127)
     {
      C[ramp(cse_var_1, 1, 64)] = broadcast(0f32, 64)
      C[ramp((cse_var_1 + 64), 1, 63)] = broadcast(0f32, 63)
      for (k: int32, 0, 127) {
        let cse_var_4: int32 = (k*127)
        let cse_var_3: int32 = (cse_var_1 + k)
        let cse_var_2: int32 = (cse_var_1 + 64)
         {
          C[ramp(cse_var_1, 1, 64)] = (C[ramp(cse_var_1, 1, 64)] + (broadcast(A[cse_var_3], 64)*B[ramp(cse_var_4, 1, 64)]))
          C[ramp(cse_var_2, 1, 63)] = (C[ramp(cse_var_2, 1, 63)] + (broadcast(A[cse_var_3], 63)*B[ramp((cse_var_4 + 64), 1, 63)]))
        }
      }
    }
  }
}

“simple_mode=False” works, and the TIR is now generated with “ramp”. But performance didn’t improve on my server, apparently because the way we print TIR with “tvm.lower()” doesn’t actually affect TVM’s build process.

But thank you for the advice, it’s very helpful, and I will try all of it later. It’s also worth noting that whatever approach we propose for non-perfect splits should also include changes to AutoTVM/auto-scheduler, so that they can consider non-perfect tiles and expand their search spaces.
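To illustrate the search-space point, here is a simplified sketch. It is not how AutoTVM or the auto-scheduler actually enumerate tiles, and `tile_candidates` is a hypothetical helper. If only perfect divisors are allowed, an extent of 127 offers no useful vector width. Allowing non-perfect tiles (powers of two here, purely for illustration) opens up widths such as 64, each with a known tail that padding, masking, or partitioning must handle.

```python
def tile_candidates(extent, allow_non_divisible, max_factor=64):
    """Return (factor, tail) pairs a tuner could try for one loop, where
    tail = extent % factor is the leftover that needs special handling."""
    limit = min(extent, max_factor)
    if not allow_non_divisible:
        return [(f, 0) for f in range(1, limit + 1) if extent % f == 0]
    candidates, f = [], 1
    while f <= limit:
        candidates.append((f, extent % f))
        f *= 2
    return candidates


print(tile_candidates(127, allow_non_divisible=False))  # [(1, 0)]
print(tile_candidates(127, allow_non_divisible=True))
# [(1, 0), (2, 1), (4, 3), (8, 7), (16, 15), (32, 31), (64, 63)]
```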