Hello, I have a question about using vthread.
Starting from the code in the Simple Matrix Multiply tutorial, I want to write a version that uses vthread.
I'm using the default configuration: BATCH = 1, BLOCK_IN = 16, BLOCK_OUT = 16.
I set up a (1x16) * (16x32) = (1x32) multiplication example so that vthread can actually be used.
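If I understand the tiling correctly, the tiling factors below come out of that shape like this (just my own arithmetic as a sanity check, not from the tutorial):

# With BATCH=1, BLOCK_IN=16, BLOCK_OUT=16, a (1x16) * (16x32) matmul tiles into:
BATCH, BLOCK_IN, BLOCK_OUT = 1, 16, 16
o = 1 // BATCH       # batch factor          -> 1
n = 16 // BLOCK_IN   # input channel factor  -> 1
m = 32 // BLOCK_OUT  # output channel factor -> 2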
(Omitted RPC connection part)
# Output channel factor m
m = 2
# Input channel factor n
n = 1
# Batch factor o (we use single batch inference)
o = 1
# A placeholder tensor in tiled data format
A = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="A", dtype=env.inp_dtype)
# B placeholder tensor in tiled data format
B = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="B", dtype=env.wgt_dtype)
# A copy buffer
A_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: A(*i), "A_buf")
# B copy buffer
B_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: B(*i), "B_buf")
# Outer input feature reduction axis
ko = tvm.reduce_axis((0, n), name="ko")
# Inner input feature reduction axis
ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki")
# Describe the in-VTA matrix multiplication
C_buf = tvm.compute(
    (o, m, env.BATCH, env.BLOCK_OUT),
    lambda bo, co, bi, ci: tvm.sum(
        A_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
        B_buf[co, ko, ci, ki].astype(env.acc_dtype),
        axis=[ko, ki]),
    name="C_buf")
# Cast to output type, and send to main memory
C = tvm.compute(
    (o, m, env.BATCH, env.BLOCK_OUT),
    lambda *i: C_buf(*i).astype(env.inp_dtype),
    name="C")
# Let's take a look at the generated schedule
s = tvm.create_schedule(C.op)
i0, i1, i2, i3 = s[C].op.axis
# VTA only supports 2 virtual threads
v_threads = 2
# Perform virtual thread split along output channel outer axis
# This is the lowered schedule before applying v_threads:
# Since only i1 and i3 remain, I split i1 into tx and bound it to cthread
# produce C {
#   for (i1, 0, 2) {
#     for (i3, 0, 16) {
#       C[((i1*16) + i3)] = int8(C_buf[((i1*16) + i3)])
#     }
#   }
# }
_, tx = s[C].split(i1, factor=v_threads)
s[C].reorder(tx, i3)
s[C].bind(tx, tvm.thread_axis("cthread"))
print(tvm.lower(s, [A, B, C], simple_mode=True))
# Set the intermediate tensor's scope to VTA's on-chip buffers
s[A_buf].set_scope(env.inp_scope)
s[B_buf].set_scope(env.wgt_scope)
s[C_buf].set_scope(env.acc_scope)
# Move buffer copy into matrix multiply loop
s[A_buf].compute_at(s[C_buf], ko)
s[B_buf].compute_at(s[C_buf], ko)
# Tag the buffer copies with the DMA pragma to insert a DMA transfer
s[A_buf].pragma(s[A_buf].op.axis[0], env.dma_copy)
s[B_buf].pragma(s[B_buf].op.axis[0], env.dma_copy)
s[C].pragma(s[C].op.axis[0], env.dma_copy)
# Let's take a look at the transformed schedule
print(tvm.lower(s, [A, B, C], simple_mode=True))
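# Reorder so that the outer reduction axis ko comes first and ki stays innermost,
# then tensorize the per-tile batch axis with VTA's GEMM intrinsic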
s[C_buf].reorder(
    ko,
    s[C_buf].op.axis[0],
    s[C_buf].op.axis[1],
    s[C_buf].op.axis[2],
    s[C_buf].op.axis[3],
    ki)
s[C_buf].tensorize(s[C_buf].op.axis[2], env.gemm)
# Build GEMM VTA kernel
with vta.build_config(debug_flag=0x6):
    my_gemm = vta.build(s, [A, B, C], "ext_dev", env.target_host, name="my_gemm")
The build fails because of some pattern-matching issues.
This is my first time using the tensor expression language, so it's hard for me to tell what is wrong in this code.
Can anyone give me some advice?
Thank you.
-jwlee