When `compute_at` is applied after `fuse` and `split`, the lowered IR is not what we expect. The following demo reproduces the problem:
import tvm
# Repro script: chain three element-wise stages (log -> multiply by 0.5 -> exp),
# then lower the schedule after each scheduling step (default, fuse, split,
# compute_at) so the generated IR can be compared step by step.
shape = [1024, 30522]
dtype = 'float16'
power_num = tvm.const(0.5, dtype=dtype)

data = tvm.placeholder(shape, name="data", dtype=dtype)
vlog_t = tvm.compute(shape, lambda *idx: tvm.log(data(*idx)), name="vlog_t")
vmuls_t = tvm.compute(shape, lambda *idx: vlog_t(*idx) * power_num, name="vmuls_t")
vexp_t = tvm.compute(shape, lambda *idx: tvm.exp(vmuls_t(*idx)), name="vexp_t")

sch = tvm.create_schedule(vexp_t.op)


def dump_ir(with_rule=True):
    """Print the lowered IR of the current schedule, optionally followed by a rule line."""
    print(tvm.lower(sch, [data, vexp_t], simple_mode=True))
    if with_rule:
        print("-----------------")


# Step 1: default schedule, nothing applied yet.
dump_ir()

# Step 2: fuse all axes of every stage into a single axis.
# The consumer (vexp_t) is listed first, followed by its producers.
stages = (vexp_t, vmuls_t, vlog_t)
fused_axes = [sch[tensor].fuse(*tensor.op.axis) for tensor in stages]
dump_ir()

# Step 3: split each fused axis into (outer, inner) using the same factor.
factor = 2048
split_axes = [
    sch[tensor].split(axis, factor=factor)
    for tensor, axis in zip(stages, fused_axes)
]
# No rule line after this dump.
dump_ir(with_rule=False)

# Step 4: attach both producers at the outer axis of the consumer vexp_t.
exp_outer = split_axes[0][0]
for producer in (vmuls_t, vlog_t):
    sch[producer].compute_at(sch[vexp_t], exp_outer)
dump_ir()
The original IR (default schedule):
// attr [vlog_t] storage_scope = "global"
allocate vlog_t[float16 * 31254528]
produce vlog_t {
for (i0, 0, 1024) {
for (i1, 0, 30522) {
vlog_t[((i0*30522) + i1)] = log(data[((i0*30522) + i1)])
}
}
}
produce vmuls_t {
for (i0, 0, 1024) {
for (i1, 0, 30522) {
vlog_t[((i0*30522) + i1)] = (vlog_t[((i0*30522) + i1)]*0.500000h)
}
}
}
produce vexp_t {
for (i0, 0, 1024) {
for (i1, 0, 30522) {
vexp_t[((i0*30522) + i1)] = exp(vlog_t[((i0*30522) + i1)])
}
}
}
The IR after fuse:
// attr [vlog_t] storage_scope = "global"
allocate vlog_t[float16 * 31254528]
produce vlog_t {
for (i0.i1.fused, 0, 31254528) {
vlog_t[i0.i1.fused] = log(data[i0.i1.fused])
}
}
produce vmuls_t {
for (i0.i1.fused, 0, 31254528) {
vlog_t[i0.i1.fused] = (vlog_t[i0.i1.fused]*0.500000h)
}
}
produce vexp_t {
for (i0.i1.fused, 0, 31254528) {
vexp_t[i0.i1.fused] = exp(vlog_t[i0.i1.fused])
}
}
The IR after fuse and split:
// attr [vlog_t] storage_scope = "global"
allocate vlog_t[float16 * 31254528]
produce vlog_t {
for (i0.i1.fused.outer, 0, 15261) {
for (i0.i1.fused.inner, 0, 2048) {
vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = log(data[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)])
}
}
}
produce vmuls_t {
for (i0.i1.fused.outer, 0, 15261) {
for (i0.i1.fused.inner, 0, 2048) {
vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = (vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)]*0.500000h)
}
}
}
produce vexp_t {
for (i0.i1.fused.outer, 0, 15261) {
for (i0.i1.fused.inner, 0, 2048) {
vexp_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = exp(vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)])
}
}
}
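The IR after compute_at (applied after fuse and split). Note that vlog_t and vmuls_t still loop over all 15261 outer blocks inside every iteration of vexp_t's outer loop: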
// attr [vlog_t] storage_scope = "global"
allocate vlog_t[float16 * 31254528]
produce vexp_t {
for (i0.i1.fused.outer, 0, 15261) {
produce vlog_t {
for (i0.i1.fused.outer, 0, 15261) {
for (i0.i1.fused.inner, 0, 2048) {
vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = log(data[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)])
}
}
}
produce vmuls_t {
for (i0.i1.fused.outer, 0, 15261) {
for (i0.i1.fused.inner, 0, 2048) {
vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = (vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)]*0.500000h)
}
}
}
for (i0.i1.fused.inner, 0, 2048) {
vexp_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = exp(vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)])
}
}
}