Understanding virtual threads

I am trying to understand the concept and use of virtual threads. The GPU convolution tutorial shows how to use vthreads to avoid shared memory conflicts. But can I achieve the same goal without using vthreads?

The following simple test case uses tiling to get three levels of loop nesting. I have three versions of the schedule. In all three versions, the outermost level of loops is bound to blockIdx and the innermost level is bound to threadIdx. The first version interchanges the middle and innermost levels, so that all the outer loops are bound. The second version just leaves the middle level unbound. The last version binds the middle level to virtual threads.

All three versions generate practically the same CUDA code.

‘virtual threads’ seems an important concept and tool in TVM. I can create a tutorial on this topic if you can help me understand it first :).

Thanks.

Yuan


import tvm

def show_cuda(s, A, B):
    ctx = tvm.context("cuda", 0)
    with tvm.build_config(dump_pass_ir=True):
        func = tvm.build(s, [A, B], target="cuda", name='test')
    print(func.imported_modules[0].get_source())

M = 7*5*2
N = 9*3*2

A = tvm.placeholder((M,N), name='A')
B = tvm.compute((M,N), lambda i,j: 3.14 * A[i,j], name='B')

block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
vthread_x = tvm.thread_axis("vthread", name="vx")
vthread_y = tvm.thread_axis("vthread", name="vy")
thread_x = tvm.thread_axis("threadIdx.x")
thread_y = tvm.thread_axis("threadIdx.y")

#
# Schedule 1: Manual loop interchange
#
#    [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 7
#    [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 9
#    [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 5
#    [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 3
#    for (i.inner.outer, 0, 2) {
#      for (j.inner.outer, 0, 2) {
#
s = tvm.create_schedule(B.op)

Mblock = 5*2
Nblock = 3*2

i_outer, j_outer, i_inner, j_inner = s[B].tile(B.op.axis[0], B.op.axis[1], Mblock, Nblock)
i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner = s[B].tile(i_inner, j_inner, 5, 3)
s[B].reorder(i_outer, j_outer, i_inner_inner, j_inner_inner, i_inner_outer, j_inner_outer)

s[B].bind(i_outer, block_y)
s[B].bind(j_outer, block_x)
s[B].bind(i_inner_inner, thread_y)
s[B].bind(j_inner_inner, thread_x)

show_cuda(s, A, B)


#
# Schedule 2: No loop interchange
#
#    [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 7
#    [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 9
#    for (i.inner.outer, 0, 2) {
#      for (j.inner.outer, 0, 2) {
#        [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 5
#        [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 3
#
s = tvm.create_schedule(B.op)

Mblock = 5*2
Nblock = 3*2

i_outer, j_outer, i_inner, j_inner = s[B].tile(B.op.axis[0], B.op.axis[1], Mblock, Nblock)
i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner = s[B].tile(i_inner, j_inner, 5, 3)

s[B].bind(i_outer, block_y)
s[B].bind(j_outer, block_x)
s[B].bind(i_inner_inner, thread_y)
s[B].bind(j_inner_inner, thread_x)

show_cuda(s, A, B)


#
# Schedule 3: use virtual threads
#
#    [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 7
#    [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 9
#    [iter_var(vy, , vthread)] virtual_thread = 2
#    [iter_var(vx, , vthread)] virtual_thread = 2
#    [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 5
#    [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 3
#
s = tvm.create_schedule(B.op)

Mblock = 5*2
Nblock = 3*2

i_outer, j_outer, i_inner, j_inner = s[B].tile(B.op.axis[0], B.op.axis[1], Mblock, Nblock)
i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner = s[B].tile(i_inner, j_inner, 5, 3)

s[B].bind(i_outer, block_y)
s[B].bind(j_outer, block_x)
s[B].bind(i_inner_outer, vthread_y)
s[B].bind(j_inner_outer, vthread_x)
s[B].bind(i_inner_inner, thread_y)
s[B].bind(j_inner_inner, thread_x)

show_cuda(s, A, B)

vthread’s definition is quite straightforward: we create innermost serial loops to simulate the concurrent execution of the threads. Because the vthreads execute in the same physical thread, the vthread lowering pass performs an optimization to detect computation that is sharable among different vthreads and compute it only once.

This compounding effect is useful for creating shared strided access patterns, such as those in GEMM.
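To make the "compute once, share across vthreads" behaviour concrete outside TVM, here is a plain-Python sketch (illustrative only, not TVM API; `expensive_load` is a made-up stand-in for a staged shared-memory load). The naive form re-evaluates the shared load inside the serial vthread loop; the lowered form hoists it out, which is essentially what the vthread lowering pass does:

```python
# Plain-Python sketch of TVM's vthread lowering (illustrative, not TVM API).
# Two "virtual threads" run as a serial loop inside one physical thread.
# A computation that does not depend on the vthread index can be shared:
# the naive form evaluates it once per vthread, the lowered form hoists
# it out of the serial loop and evaluates it only once.

calls = {"naive": 0, "lowered": 0}

def expensive_load(a, counter_key):
    """Stand-in for a staged load that all vthreads need."""
    calls[counter_key] += 1
    return a * a

def naive(a, x):
    out = [0.0] * len(x)
    for v in range(len(x)):             # serial loop simulating vthreads
        s = expensive_load(a, "naive")  # recomputed for every vthread
        out[v] = s * x[v]
    return out

def lowered(a, x):
    s = expensive_load(a, "lowered")    # hoisted: computed once, shared
    out = [0.0] * len(x)
    for v in range(len(x)):
        out[v] = s * x[v]
    return out

x = [1.0, 2.0]
assert naive(3.0, x) == lowered(3.0, x)  # same result...
print(calls)                             # ...but fewer shared loads
```

With plain split+reorder the middle loop stays an ordinary serial loop, so the compiler has no license to hoist and share work across its iterations the way the vthread lowering does.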

Tianqi, thanks for your explanation. Would loop splitting with interchange not achieve the same effect, as shown by ‘schedule 1’ in my sample code above?


I was lurking in this thread because I also had a question about vthreads (in the special case of the VTA).
In the VTA tech report there is a whole subsection about latency hiding using virtual threads.

Also, checking the source code of inject_virtual_thread:


Requires that the axis be labelled as “vthread”.

In the VTA tutorial

So the label is “cthread”

  1. Are all “cthreads” also virtual threads?
  2. Why is it in this case better to define this axis as “cthread” and not “vthread”?

Also, I have a question concerning this part of the VTA environment

  1. Can I interpret this as: “the coprocessor sync function for the VTA architecture is an external call to the VTASynchronize routine”?
  2. Why isn’t the VTASynchronize (also VTADepPush and VTADepPop) routine inlined in the output of the tvm.lower() routine?
// attr [res_conv] storage_scope = "local.acc_buffer"
// attr [data_buf] storage_scope = "local.inp_buffer"
// attr [kernel_buf] storage_scope = "local.wgt_buffer"
produce res {
/* There are a lot of lines here */
}
vta.coproc_sync()

Hello, I am new to TVM and still learning it.
I built the kernel from your code, but I find that the first two schedules generate the same CUDA code:

extern "C" __global__ void test_kernel0( float* __restrict__ B,  float* __restrict__ A) {
  for (int i_inner_outer = 0; i_inner_outer < 2; ++i_inner_outer) {
    for (int j_inner_outer = 0; j_inner_outer < 2; ++j_inner_outer) {
      B[((((((((int)blockIdx.y) * 540) + (i_inner_outer * 270)) + (((int)threadIdx.y) * 54)) + (((int)blockIdx.x) * 6)) + (j_inner_outer * 3)) + ((int)threadIdx.x))] = (A[((((((((int)blockIdx.y) * 540) + (i_inner_outer * 270)) + (((int)threadIdx.y) * 54)) + (((int)blockIdx.x) * 6)) + (j_inner_outer * 3)) + ((int)threadIdx.x))] * 3.140000e+00f);
    }
  }
}

So it seems the reorder command does not affect the CUDA code generated by TVM. I am confused; would you mind explaining what happens here? Thanks!

Hi, I am also confused about whether we can use reorder instead of the new concept of virtual threads. The two snippets below produce the same result. Using a virtual thread:

A = tvm.placeholder((8, ), name='A')
A1 = tvm.compute((8,), lambda i: A[i], name='A1')
s = tvm.create_schedule(A1.op)
vx = tvm.thread_axis((0, 2), "vthread", name="vx")
tx = tvm.thread_axis((0, 2), "threadIdx.x")
xo, xi = s[A1].split(A1.op.axis[0], nparts=2)
s[A1].bind(xo, vx)
xm, xi = s[A1].split(xi, nparts=2)
s[A1].bind(xm, tx)
print(tvm.lower(s, [A, A1], simple_mode=True))
#produce A1 {
#// attr [iter_var(threadIdx.x, Range(min=0, extent=2), threadIdx.x)] thread_extent = 2
#for (i.inner.inner, 0, 2) {
#  for (vx.s, 0, 2) {
#  A1[(((vx.s*4) + (threadIdx.x*2)) + i.inner.inner)] = A[(((vx.s*4) + (threadIdx.x*2)) + i.inner.inner)]
#}
#}
#} 

and using reorder:

A = tvm.placeholder((8, ), name='A')
A1 = tvm.compute((8,), lambda i: A[i], name='A1')

s = tvm.create_schedule(A1.op)
tx = tvm.thread_axis((0, 2), "threadIdx.x")
xo, xi = s[A1].split(A1.op.axis[0], nparts=2)
xm, xi = s[A1].split(xi, nparts=2)
s[A1].bind(xm, tx)
s[A1].reorder(xm,xi,xo)
print(tvm.lower(s, [A, A1], simple_mode=True))
#produce A1 {
#// attr [iter_var(threadIdx.x, Range(min=0, extent=2), threadIdx.x)] thread_extent = 2
#for (i.inner.inner, 0, 2) {
#for (i.outer, 0, 2) {
#  A1[(((threadIdx.x*2) + i.inner.inner) + (i.outer*4))] = A[(((threadIdx.x*2) + i.inner.inner) + (i.outer*4))]
#}
#}
#}

Thanks!

Hi, I have some new understanding of my previous question: virtual threads cannot be replaced by split+reorder when there is a complex memory hierarchy. When inferring bounds, a virtual thread asks TVM to infer the corresponding bound by treating the axis as a thread, which is quite different from a normal serial loop (see the function ‘NeedRelax’ for more detail).
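Here is a plain-Python sketch of that bound-inference difference (illustrative only, not TVM code; the numbers mirror the 8-element example above with 2 vthreads of stride 4). For a normal serial loop, a staging buffer can be sized for just the region one iteration touches; for an axis marked as a (virtual) thread, a NeedRelax-style relaxation takes the union over all thread indices, so the staged region covers every virtual thread's accesses and can be filled once:

```python
# Illustrative sketch (not TVM API): region each iteration v of an axis
# of extent 2 reads from an 8-element input, with stride 4.
def touched(v):
    # iteration v reads input[v*4 : v*4 + 4]
    return (v * 4, v * 4 + 4)

extent = 2

# Normal serial loop: bounds inferred per iteration, so a staging
# buffer only needs to cover one iteration's region (size 4).
per_iter = [touched(v) for v in range(extent)]
per_iter_sizes = [hi - lo for lo, hi in per_iter]

# vthread axis: the bound is relaxed to the union over all v, so the
# inferred region covers the whole input (size 8) and can be staged
# (e.g. into shared memory) once for all virtual threads.
lo = min(l for l, _ in per_iter)
hi = max(h for _, h in per_iter)
relaxed_size = hi - lo

print(per_iter_sizes, relaxed_size)  # [4, 4] 8
```

This is why split+reorder reproduces the loop structure but not the sharing: only the vthread binding tells bound inference to relax over the axis.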
