Hi, I am also confused about whether we could simply use `reorder` instead of introducing a new concept, the virtual thread.
The two snippets below produce equivalent lowered code. Using a virtual thread:
# Identity copy A1[i] = A[i] over an 8-element axis.
A = tvm.placeholder((8,), name='A')
A1 = tvm.compute((8,), lambda i: A[i], name='A1')
s = tvm.create_schedule(A1.op)

# Virtual-thread axis and real thread axis, each with extent 2.
vx = tvm.thread_axis((0, 2), "vthread", name="vx")
tx = tvm.thread_axis((0, 2), "threadIdx.x")

# Split the axis into 2 outer parts and bind the outer part to the virtual thread.
xo, xi = s[A1].split(A1.op.axis[0], nparts=2)
s[A1].bind(xo, vx)

# Split the remainder into 2 parts again and bind that outer part to threadIdx.x.
xm, xi = s[A1].split(xi, nparts=2)
s[A1].bind(xm, tx)

# Pass the args as a list (not a set) so the lowered function's parameter
# order is deterministic; print so the IR is visible when run as a script.
print(tvm.lower(s, [A, A1], simple_mode=True))
# Output:
# produce A1 {
#   // attr [iter_var(threadIdx.x, Range(min=0, extent=2), threadIdx.x)] thread_extent = 2
#   for (i.inner.inner, 0, 2) {
#     for (vx.s, 0, 2) {
#       A1[(((vx.s*4) + (threadIdx.x*2)) + i.inner.inner)] = A[(((vx.s*4) + (threadIdx.x*2)) + i.inner.inner)]
#     }
#   }
# }
And the same computation using `reorder` instead:
# Identity copy A1[i] = A[i] over an 8-element axis.
A = tvm.placeholder((8,), name='A')
A1 = tvm.compute((8,), lambda i: A[i], name='A1')
s = tvm.create_schedule(A1.op)

# Only a real thread axis here (extent 2); no virtual thread.
tx = tvm.thread_axis((0, 2), "threadIdx.x")

# Same two splits as the vthread version, but only the middle axis is bound.
xo, xi = s[A1].split(A1.op.axis[0], nparts=2)
xm, xi = s[A1].split(xi, nparts=2)
s[A1].bind(xm, tx)

# Move the outer split axis innermost, mimicking where vthread placed vx.s.
s[A1].reorder(xm, xi, xo)

# Pass the args as a list (not a set) so the lowered function's parameter
# order is deterministic; print so the IR is visible when run as a script.
print(tvm.lower(s, [A, A1], simple_mode=True))
# Output:
# produce A1 {
#   // attr [iter_var(threadIdx.x, Range(min=0, extent=2), threadIdx.x)] thread_extent = 2
#   for (i.inner.inner, 0, 2) {
#     for (i.outer, 0, 2) {
#       A1[(((threadIdx.x*2) + i.inner.inner) + (i.outer*4))] = A[(((threadIdx.x*2) + i.inner.inner) + (i.outer*4))]
#     }
#   }
# }
Thanks!