tvm.thread_axis() accepts an optional domain argument. Could anyone please explain how the domain is used?
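I tried the following example, but the domain I pass does not seem to affect the schedule: in the output below, the thread extents for threadIdx.y and threadIdx.x are 10 and 18 (the tile sizes) rather than the 9 and 5 I specified as domains.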
# Problem shape, written as products of small factors.
M = 7 * 5 * 2
N = 11 * 9 * 2
# Tile sizes handed to tile() for the first and second axis of B.
Mblock = 5 * 2
Nblock = 9 * 2

# Element-wise kernel: B[i, j] = 3.14 * A[i, j].
A = tvm.placeholder((M, N), name='A')
B = tvm.compute((M, N), lambda i,j: 3.14 * A[i, j], name='B')
s = tvm.create_schedule(B.op)

# Tile both axes of B; tile() yields the two outer and two inner loop variables.
row_axis = B.op.axis[0]
col_axis = B.op.axis[1]
i_outer, j_outer, i_inner, j_inner = s[B].tile(row_axis, col_axis, Mblock, Nblock)

# The block axes are created without a domain. The thread axes are given
# explicit domains (0, 5) and (0, 9), which differ from the tile sizes
# (Nblock = 18 and Mblock = 10) of the inner loops they are bound to below.
block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
thread_x = tvm.thread_axis((0,5), "threadIdx.x")
thread_y = tvm.thread_axis((0,9), "threadIdx.y")

# Bind outer loops to block indices, then inner loops to thread indices.
for loop_var, thread_var in (
    (i_outer, block_y),
    (j_outer, block_x),
    (i_inner, thread_y),
    (j_inner, thread_x),
):
    s[B].bind(loop_var, thread_var)
Output
produce B {
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 7
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 11
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 10
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 18
B[(((((blockIdx.y*110) + blockIdx.x) + (threadIdx.y*11))*18) + threadIdx.x)] = (A[(((((blockIdx.y*110) + blockIdx.x) + (threadIdx.y*11))*18) + threadIdx.x)]*3.140000f)
}