I am trying to understand the concept and use of virtual threads. The GPU convolution tutorial shows how to use vthreads to avoid shared memory conflicts. But can I achieve the same goal without using vthreads?
The following simple test case uses tiling to create three levels of loop nesting. I wrote three versions of the schedule. In all three, the outermost loops are bound to blockIdx and the innermost loops are bound to threadIdx. The first version interchanges the middle and innermost levels, so that all of the outer loops are bound. The second version simply leaves the middle level unbound. The third version binds the middle level to virtual threads.
All three versions generate practically the same CUDA code.
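To convince myself of this, I wrote a small pure-Python sketch (no TVM involved; the constants and loop names just mirror the tiling in the code below) of the index arithmetic that, as far as I can tell, all three schedules lower to. The middle 2x2 level stays a serial loop inside each CUDA thread whether it is interchanged, left unbound, or bound to vthread:

```python
# Pure-Python sketch of the per-thread index arithmetic; names are
# illustrative, not TVM APIs.
M, N = 7 * 5 * 2, 9 * 3 * 2  # same shape as the test case below

def elements(block_y, block_x, thread_y, thread_x):
    """Elements of B computed by one CUDA thread: the 2x2 middle level
    remains a serial loop in the generated kernel in all three versions."""
    out = set()
    for vy in range(2):          # i.inner.outer / vthread vy
        for vx in range(2):      # j.inner.outer / vthread vx
            i = block_y * 10 + vy * 5 + thread_y
            j = block_x * 6 + vx * 3 + thread_x
            out.add((i, j))
    return out

# Every element of B is covered exactly once across the launch grid
# (grid = 7x9 blocks, block = 5x3 threads, 4 elements per thread).
covered = set()
for by in range(7):
    for bx in range(9):
        for ty in range(5):
            for tx in range(3):
                covered |= elements(by, bx, ty, tx)

assert covered == {(i, j) for i in range(M) for j in range(N)}
```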
‘virtual threads’ seems to be an important concept and tool in TVM. I would be happy to create a tutorial on this topic if you can help me understand it first :).
Thanks.
Yuan
import tvm

def show_cuda(s, A, B):
    ctx = tvm.context("cuda", 0)
    with tvm.build_config(dump_pass_ir=True) as cfg:
        func = tvm.build(s, [A, B], target="cuda", name='test')
        print(func.imported_modules[0].get_source())
M = 7*5*2
N = 9*3*2
A = tvm.placeholder((M,N), name='A')
B = tvm.compute((M,N), lambda i,j: 3.14 * A[i,j], name='B')
block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
vthread_x = tvm.thread_axis("vthread", name="vx")
vthread_y = tvm.thread_axis("vthread", name="vy")
thread_x = tvm.thread_axis("threadIdx.x")
thread_y = tvm.thread_axis("threadIdx.y")
#
# Schedule 1: Manual loop interchange
#
# [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 7
# [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 9
# [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 5
# [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 3
# for (i.inner.outer, 0, 2) {
# for (j.inner.outer, 0, 2) {
#
s = tvm.create_schedule(B.op)
Mblock = 5*2
Nblock = 3*2
i_outer, j_outer, i_inner, j_inner = s[B].tile(B.op.axis[0], B.op.axis[1], Mblock, Nblock)
i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner = s[B].tile(i_inner, j_inner, 5, 3)
s[B].reorder(i_outer, j_outer, i_inner_inner, j_inner_inner, i_inner_outer, j_inner_outer)
s[B].bind(i_outer, block_y)
s[B].bind(j_outer, block_x)
s[B].bind(i_inner_inner, thread_y)
s[B].bind(j_inner_inner, thread_x)
show_cuda(s, A, B)
#
# Schedule 2: No loop interchange
#
# [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 7
# [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 9
# for (i.inner.outer, 0, 2) {
# for (j.inner.outer, 0, 2) {
# [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 5
# [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 3
#
s = tvm.create_schedule(B.op)
Mblock = 5*2
Nblock = 3*2
i_outer, j_outer, i_inner, j_inner = s[B].tile(B.op.axis[0], B.op.axis[1], Mblock, Nblock)
i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner = s[B].tile(i_inner, j_inner, 5, 3)
s[B].bind(i_outer, block_y)
s[B].bind(j_outer, block_x)
s[B].bind(i_inner_inner, thread_y)
s[B].bind(j_inner_inner, thread_x)
show_cuda(s, A, B)
#
# Schedule 3: use virtual threads
#
# [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 7
# [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 9
# [iter_var(vy, , vthread)] virtual_thread = 2
# [iter_var(vx, , vthread)] virtual_thread = 2
# [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 5
# [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 3
#
s = tvm.create_schedule(B.op)
Mblock = 5*2
Nblock = 3*2
i_outer, j_outer, i_inner, j_inner = s[B].tile(B.op.axis[0], B.op.axis[1], Mblock, Nblock)
i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner = s[B].tile(i_inner, j_inner, 5, 3)
s[B].bind(i_outer, block_y)
s[B].bind(j_outer, block_x)
s[B].bind(i_inner_outer, vthread_y)
s[B].bind(j_inner_outer, vthread_x)
s[B].bind(i_inner_inner, thread_y)
s[B].bind(j_inner_inner, thread_x)
show_cuda(s, A, B)