Hi Pei,
IMO, after the InferRootBound step, the ranges of the root iter vars of the current producer stage may change, because each consumer may request a different range for each dim.
For example, here we split the axis of z_global.
import tvm
from tvm import te
n = 16
factor = 3
x = te.placeholder((n,), name="vx")
y = te.placeholder((n,), name="vy")
z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
s = te.create_schedule(z.op)
z_global = s.cache_write(z, "global")
_, _ = s[z_global].split(z_global.op.axis[0], factor=factor)
tvm.lower(s, [x, y, z])
# IR
// attr [compute(z.global, body=[(vx[i.c] + vy[i.c])], axis=[iter_var(i.c, range(min=0, ext=16))], reduce_axis=[], tag=, attrs={})] realize_scope = "global"
producer_realize z.global([0, 16]) {
  for (i.c.outer, 0, 6) {
    for (i.c.inner, 0, 3) {
      if (tir.likely(((i.c.inner + (i.c.outer*3)) < 16))) {
        z.global[(i.c.inner + (i.c.outer*3))] =(vx[(i.c.inner + (i.c.outer*3))] + vy[(i.c.inner + (i.c.outer*3))])
      }
    }
  }
  // attr [compute(z, body=[(vx[i] + vy[i])], axis=[iter_var(i, range(min=0, ext=16))], reduce_axis=[], tag=, attrs={})] realize_scope = ""
  producer_realize z([0, 16]) {
    for (i, 0, 16) {
      z[i] =z.global[i]
    }
  }
}
As we can see, for stage z.global, the consumer z requests the whole range [0, 16) in dim i. In more complicated cases (compute_at + split), different consumers may request different ranges per dim.
tvm/src/te/operation/compute_op.cc at b7e0cfb6d469c3745ae2195908daadea9c64d87e · apache/tvm · GitHub
Then, the next step is to take the union of all the requested bounds.
tvm/src/te/operation/compute_op.cc at b7e0cfb6d469c3745ae2195908daadea9c64d87e · apache/tvm · GitHub
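To make the union step concrete, here is a rough sketch in plain Python. It is not TVM's actual code: `union_ranges` is a name I made up, ranges are simplified to constant `(min, extent)` pairs, and the two consumer requests are hypothetical.

```python
def union_ranges(requests):
    """Return the smallest (min, extent) range covering every requested range.

    Each request is a (min, extent) pair for one dim of the producer.
    Illustrative helper only, not a TVM API.
    """
    lo = min(m for m, _ in requests)
    hi = max(m + e for m, e in requests)
    return lo, hi - lo


# Hypothetical requests from two consumers on the same dim:
# one reads [0, 8), the other reads [4, 16).
requests = [(0, 8), (4, 12)]
print(union_ranges(requests))  # (0, 16)
```

In the real pass the bounds are symbolic expressions rather than integers, but the idea is the same: the producer has to compute everything that any consumer asks for.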
At this point, the range of each dim of the current producer stage may have changed. So we need another step, PassDownDomain, to propagate the bounds from the root iter vars down to the leaf iter vars.
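For the split in the example above, the pass-down can be sketched like this. Again this is plain Python under simplifying assumptions (constant extents, a single split with a fixed factor); `pass_down_split` is a hypothetical helper, not the TVM function.

```python
def pass_down_split(parent_extent, factor):
    """Derive leaf extents for a split of a root iter var.

    Returns (outer_extent, inner_extent, needs_guard), where needs_guard
    says whether outer * inner overshoots the parent range, so a bound
    check is required in the loop body. Illustrative helper only.
    """
    outer = (parent_extent + factor - 1) // factor  # ceil division
    inner = factor
    needs_guard = outer * inner > parent_extent
    return outer, inner, needs_guard


# Root range of i.c is [0, 16) and the split factor is 3.
print(pass_down_split(16, 3))  # (6, 3, True)
```

This matches the IR above: i.c.outer has extent 6, i.c.inner has extent 3, and since 6 * 3 = 18 > 16 the body is wrapped in the `tir.likely(... < 16)` check.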