Hi all,
I defined a toy computation and scheduled it in TVM, and I am having some difficulty understanding how the lowered code TVM produces corresponds to the schedule. I have reproduced both the Python code and the lowered IR below.
Python code:
import tvm
from tvm import te
batch_size = 24
hidden_size = 256
length = 100
scan_state = te.placeholder((length, batch_size, hidden_size))
scan_init = te.placeholder((1, batch_size, hidden_size))
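# one step of the recurrence: state[t] = state[t-1] * 4 + 7892 (C multiplies, D adds the constant)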
C = te.compute((length, batch_size, hidden_size), lambda t, b, i: scan_state[t - 1, b, i] * 4)
D = te.compute((length, batch_size, hidden_size), lambda t, b, i: C[t, b, i] + 7892)
scan_update = D
scan = tvm.te.scan(scan_init, scan_update, scan_state)
s = te.create_schedule([scan.op])
bx = te.thread_axis((0, batch_size), "blockIdx.x")
tx = te.thread_axis((0, hidden_size), "threadIdx.x")
s[scan].env_threads([bx])
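# axis 1 is the batch axis (extent 24), so splitting by factor 24 gives xo extent 1 and xi extent 24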
xo, xi = s[C].split(s[C].op.axis[1], factor=24)
s[C].bind(xi, bx)
s[C].bind(s[C].op.axis[2], tx)
xo, xi = s[D].split(s[D].op.axis[1], factor=24)
s[D].bind(xi, bx)
s[D].bind(s[D].op.axis[2], tx)
print(tvm.lower(s, [scan_init, scan_state, scan_update, scan], simple_mode=True))
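For reference, this is the recurrence I expect the scan to implement, written out in NumPy (my own reading of the te.scan semantics, not anything TVM generated):

import numpy as np

def scan_reference(init):
    # my expectation of what the scan computes, not TVM output:
    # init has shape (1, batch_size, hidden_size) and seeds state[0];
    # every later step applies state[t] = state[t-1] * 4 + 7892
    state = np.empty((length,) + init.shape[1:], dtype=init.dtype)
    state[0] = init[0]
    for t in range(1, length):
        state[t] = state[t - 1] * 4 + 7892
    return state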
Lowered IR:
produce scan {
  // attr [iter_var(blockIdx.x, range(min=0, ext=24), blockIdx.x)] thread_extent = 24
  // attr [compute] storage_scope = "shared"
  allocate compute[float32 * 256]
  for (scan.idx, 0, 99) {
    produce compute {
      // attr [iter_var(threadIdx.x, range(min=0, ext=256), threadIdx.x)] thread_extent = 256
      if (likely((blockIdx.x < 1))) {
        if (likely((blockIdx.x < 12))) {
          compute[((blockIdx.x*256) + threadIdx.x)] = (scan[(((scan.idx*6144) + (blockIdx.x*512)) + threadIdx.x)]*4f)
        }
      }
    }
    // attr [iter_var(threadIdx.x, range(min=0, ext=256), threadIdx.x)] thread_extent = 256
    scan[((((scan.idx*6144) + (blockIdx.x*256)) + threadIdx.x) + 6144)] = (compute[threadIdx.x] + 7892f)
  }
}
Specifically, I do not understand the predicates (blockIdx.x < 1 and blockIdx.x < 12) guarding the first compute op (the one producing C). blockIdx.x is bound over the full batch extent of 24, so where do these bounds come from?
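For what it's worth, the rest of the indexing does make sense to me. With shape (length, batch_size, hidden_size) = (100, 24, 256), element (t, b, i) of the scan buffer sits at flat offset t*6144 + b*256 + i, so the trailing + 6144 on the store is just the write landing at time step scan.idx + 1. A quick sanity check of that arithmetic (my own bookkeeping, not TVM output):

def flat_index(t, b, i):
    # row-major layout of the (100, 24, 256) scan buffer
    return t * 24 * 256 + b * 256 + i

assert flat_index(1, 0, 0) == 6144  # the "+ 6144" on the store: one full time step
assert flat_index(0, 1, 0) == 256   # the "blockIdx.x*256" term: one batch row

It is only the two likely(...) predicates (and perhaps the blockIdx.x*512 stride on the read side) that I cannot map back to the schedule.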