[VTA] Question on VTA matrix multiplication using vthread

Hello, I have a question about using vthread.

Starting from the code in the Simple Matrix Multiply tutorial, I want to write a version that uses vthread.

I’m using the default configuration: BATCH = 1, BLOCK_IN = 16, BLOCK_OUT = 16.

I set up a (1×16) * (16×32) = (1×32) multiplication as the example so that vthread can be used.

(Omitted RPC connection part)
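For context, everything above this snippet follows the tutorial; roughly, it just loads the VTA environment (the RPC connection part is still omitted):

import tvm
import vta

# Load the VTA hardware parameters (BATCH, BLOCK_IN/OUT, dtypes, buffer scopes)
env = vta.get_env()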

# Output channel factor m 
m = 2
# Input channel factor n 
n = 1
# Batch factor o (we use single batch inference)
o = 1
# A placeholder tensor in tiled data format
A = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="A", dtype=env.inp_dtype)
# B placeholder tensor in tiled data format
B = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="B", dtype=env.wgt_dtype)
# A copy buffer
A_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: A(*i), "A_buf")
# B copy buffer
B_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: B(*i), "B_buf")

# Outer input feature reduction axis
ko = tvm.reduce_axis((0, n), name="ko")
# Inner input feature reduction axis
ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki")
# Describe the in-VTA matrix multiplication
C_buf = tvm.compute(
    (o, m, env.BATCH, env.BLOCK_OUT),
    lambda bo, co, bi, ci:
        tvm.sum(A_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
                B_buf[co, ko, ci, ki].astype(env.acc_dtype),
                axis=[ko, ki]),
    name="C_buf")

# Cast to output type, and send to main memory
C = tvm.compute(
    (o, m, env.BATCH, env.BLOCK_OUT),
    lambda *i: C_buf(*i).astype(env.inp_dtype),
    name="C")

# Let's take a look at the generated schedule
s = tvm.create_schedule(C.op)

i0, i1, i2, i3 = s[C].op.axis

# VTA only supports 2 virtual threads
v_threads = 2

# Perform virtual thread split along the output channel outer axis
# Below is the lowered schedule before applying v_threads.
# Since only i1 and i3 remain, I split i1 into tx and bound it to cthread.

# produce C {
#   for (i1, 0, 2) {
#     for (i3, 0, 16) {
#       C[((i1*16) + i3)] = int8(C_buf[((i1*16) + i3)])
#     }
#   }
# }

_, tx = s[C].split(i1, factor=v_threads)
s[C].reorder(tx, i3)
s[C].bind(tx, tvm.thread_axis("cthread"))

print(tvm.lower(s, [A, B, C], simple_mode=True))

# Set the intermediate tensor's scope to VTA's on-chip buffers
s[A_buf].set_scope(env.inp_scope)
s[B_buf].set_scope(env.wgt_scope)
s[C_buf].set_scope(env.acc_scope)

# Move buffer copy into matrix multiply loop
s[A_buf].compute_at(s[C_buf], ko)
s[B_buf].compute_at(s[C_buf], ko)

# Tag the buffer copies with the DMA pragma to insert a DMA transfer
s[A_buf].pragma(s[A_buf].op.axis[0], env.dma_copy)
s[B_buf].pragma(s[B_buf].op.axis[0], env.dma_copy)
s[C].pragma(s[C].op.axis[0], env.dma_copy)

# Let's take a look at the transformed schedule
print(tvm.lower(s, [A, B, C], simple_mode=True))

s[C_buf].reorder(
    ko,
    s[C_buf].op.axis[0],
    s[C_buf].op.axis[1],
    s[C_buf].op.axis[2],
    s[C_buf].op.axis[3],
    ki)

s[C_buf].tensorize(s[C_buf].op.axis[2], env.gemm)

# Build GEMM VTA kernel
with vta.build_config(debug_flag=0x6):
    my_gemm = vta.build(s, [A, B, C], "ext_dev", env.target_host, name="my_gemm")

The build fails because of a pattern-matching issue.

This is my first time using the tensor expression language, so it’s hard for me to tell what is wrong in this code.

Can anyone give me some advice?

Thank you.

-jwlee

Hmmm, that might be an issue with the IR pass that tries to pattern-match 2D DMA accesses in this file (https://github.com/dmlc/tvm/blob/master/vta/python/vta/ir_pass.py#L324). Can you share the error that you’re getting?

Sure, here is the error:

Traceback (most recent call last):

  File "vthread_mm.py", line 444, in <module>
    my_gemm = vta.build(s, [A, B, C], "ext_dev", env.target_host, name="my_gemm")

  File "/home/jwlee/jwlee/up-to-date/with-zcu104/tvm/vta/python/vta/build_module.py", line 118, in build
    return tvm.build(*args, **kwargs)

  File "/home/jwlee/jwlee/up-to-date/with-zcu104/tvm/python/tvm/build_module.py", line 573, in build
    binds=binds)

  File "/home/jwlee/jwlee/up-to-date/with-zcu104/tvm/python/tvm/build_module.py", line 384, in lower
    stmt = f(stmt)

  File "/home/jwlee/jwlee/up-to-date/with-zcu104/tvm/python/tvm/_ffi/_ctypes/function.py", line 210, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (1) /home/jwlee/jwlee/up-to-date/with-zcu104/tvm/build/libtvm.so(TVMFuncCall+0x61) [0x7f8e2c2bca51]
  [bt] (0) /home/jwlee/jwlee/up-to-date/with-zcu104/tvm/build/libtvm.so(+0xba4f4b) [0x7f8e2c2b7f4b]
  File "/home/jwlee/jwlee/up-to-date/with-zcu104/tvm/python/tvm/_ffi/_ctypes/function.py", line 72, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/jwlee/jwlee/up-to-date/with-zcu104/tvm/vta/python/vta/ir_pass.py", line 579, in inject_dma_intrin
    return tvm.ir_pass.InjectCopyIntrin(stmt_in, "dma_copy", _inject_copy)
  File "/home/jwlee/jwlee/up-to-date/with-zcu104/tvm/python/tvm/_ffi/_ctypes/function.py", line 210, in __call__
    raise get_last_ffi_error()
  [bt] (8) /home/jwlee/jwlee/up-to-date/with-zcu104/tvm/build/libtvm.so(std::_Function_handler<tvm::Stmt (tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*), tvm::IRFunctor<tvm::Stmt (tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)>::set_dispatch<tvm::ir::ProducerConsumer>(std::function<tvm::Stmt (tvm::ir::ProducerConsumer const*, tvm::Stmt const&, tvm::ir::IRMutator*)>)::{lambda(tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*&&)+0x53) [0x7f8e2bdba473]
  [bt] (7) /home/jwlee/jwlee/up-to-date/with-zcu104/tvm/build/libtvm.so(+0x69d364) [0x7f8e2bdb0364]
  [bt] (6) /home/jwlee/jwlee/up-to-date/with-zcu104/tvm/build/libtvm.so(tvm::ir::IRMutator::Mutate_(tvm::ir::ProducerConsumer const*, tvm::Stmt const&)+0x49) [0x7f8e2bdb3ba9]
  [bt] (5) /home/jwlee/jwlee/up-to-date/with-zcu104/tvm/build/libtvm.so(tvm::ir::IRMutator::Mutate(tvm::Stmt)+0x5b) [0x7f8e2bb8c53b]
  [bt] (4) /home/jwlee/jwlee/up-to-date/with-zcu104/tvm/build/libtvm.so(tvm::IRFunctor<tvm::Stmt (tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)>::operator()(tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*) const+0x434) [0x7f8e2bb8c024]
  [bt] (3) /home/jwlee/jwlee/up-to-date/with-zcu104/tvm/build/libtvm.so(std::_Function_handler<tvm::Stmt (tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*), tvm::IRFunctor<tvm::Stmt (tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)>::set_dispatch<tvm::ir::AttrStmt>(std::function<tvm::Stmt (tvm::ir::AttrStmt const*, tvm::Stmt const&, tvm::ir::IRMutator*)>)::{lambda(tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*&&)+0x53) [0x7f8e2bdba0f3]
  [bt] (2) /home/jwlee/jwlee/up-to-date/with-zcu104/tvm/build/libtvm.so(+0x69d054) [0x7f8e2bdb0054]
  [bt] (1) /home/jwlee/jwlee/up-to-date/with-zcu104/tvm/build/libtvm.so(tvm::ir::CopyIntrinInjector::Mutate_(tvm::ir::AttrStmt const*, tvm::Stmt const&)+0x2a1) [0x7f8e2bdd1781]
  [bt] (0) /home/jwlee/jwlee/up-to-date/with-zcu104/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7f8e2bb2b4a2]
  File "/home/jwlee/jwlee/up-to-date/with-zcu104/tvm/src/pass/inject_copy_intrin.cc", line 50
TVMError: Check failed: MatchCopyPattern(op->body, &ret): Cannot match copy pattern of // attr [iter_var(cthread, , cthread)] virtual_thread = 2
for (i3, 0, 16) {
  C[((cthread*16) + i3)] = int8(C_buf[i3])
}

As you mentioned, the schedule after applying dma_copy isn’t generated the way I expected:

// attr [C_buf] storage_scope = "local.acc_buffer"
// attr [A_buf] storage_scope = "local.inp_buffer"
// attr [B_buf] storage_scope = "local.wgt_buffer"
produce C_buf {
  for (ci, 0, 16) {
    C_buf[ci] = 0
    produce A_buf {
      // attr [iter_var(i0, )] pragma_dma_copy = 1
      for (i3, 0, 16) {
        A_buf[i3] = A[i3]
      }
    }
    produce B_buf {
      // attr [iter_var(i0, )] pragma_dma_copy = 1
      for (i3, 0, 16) {
        B_buf[i3] = B[(((cthread*256) + (ci*16)) + i3)]
      }
    }
    for (ki, 0, 16) {
      C_buf[ci] = (C_buf[ci] + (int32(A_buf[ki])*int32(B_buf[ki])))
    }
  }
}
produce C {
  // attr [iter_var(i2, )] pragma_dma_copy = 1
  for (i3, 0, 16) {
    C[i3] = int8(C_buf[i3])
    C[(i3 + 16)] = int8(C_buf[i3])
  }
}

The produce C part shows C[(i3 + 16)] = int8(C_buf[i3]), which I think should be C[(i3 + 16)] = int8(C_buf[(i3 + 16)]).

I’m confused about how to use the dma_copy pragma correctly in this case.
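For reference, one thing I’m planning to try next (untested, so it may be the wrong direction) is to also compute C_buf under the cthread-bound axis of C, so that each virtual thread accumulates into its own region of the accumulator buffer instead of both threads reading the same C_buf[i3]:

# Untested idea: anchor C_buf inside the virtual-threaded axis of C,
# so each cthread gets its own accumulation region in local.acc_buffer.
# (tx is the axis that was bound to "cthread" above.)
s[C_buf].compute_at(s[C], tx)

I’m not sure how this interacts with the dma_copy pragma on s[C].op.axis[0], though.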