compute_at after fuse & split: the result is wrong

When compute_at is applied after fuse and split, the generated IR is not what we expect. The demo looks like this:

import tvm

# Reproduction script: build three chained elementwise stages, then apply
# fuse -> split -> compute_at, printing the lowered IR after each step.

# Input tensor: a 1024 x 30522 float16 placeholder.
shape = [1024, 30522]
dtype = 'float16'
power_num = tvm.const(0.5, dtype = dtype)
data = tvm.placeholder(shape, name="data", dtype=dtype)

# Elementwise pipeline: vexp_t = exp(log(data) * 0.5).
vlog_t = tvm.compute(shape, lambda *indice: tvm.log(data(*indice)), name = "vlog_t")
vmuls_t = tvm.compute(shape, lambda *indice: vlog_t(*indice) * power_num, name = "vmuls_t")
vexp_t = tvm.compute(shape, lambda *indice: tvm.exp(vmuls_t(*indice)), name = "vexp_t")

s = tvm.create_schedule(vexp_t.op)

# Step 1: lowered IR of the unmodified schedule.
print(tvm.lower(s, [data, vexp_t], simple_mode=True))
print("-----------------")

# Step 2: fuse all axes of each stage into a single axis per stage.
exp_fused_axis = s[vexp_t].fuse(*vexp_t.op.axis)
muls_fused_axis = s[vmuls_t].fuse(*vmuls_t.op.axis)
log_fused_axis = s[vlog_t].fuse(*vlog_t.op.axis)

print(tvm.lower(s, [data, vexp_t], simple_mode=True))
print("-----------------")

# Step 3: split each fused axis into an outer/inner pair using the same factor.
# Only xo is used below; xi, mo, mi, lo and li are not referenced again.
factor = 2048
xo, xi = s[vexp_t].split(exp_fused_axis, factor=factor)
mo, mi = s[vmuls_t].split(muls_fused_axis, factor)
lo, li = s[vlog_t].split(log_fused_axis, factor)

# NOTE: unlike the other steps, no separator line is printed after this dump.
print(tvm.lower(s, [data, vexp_t], simple_mode=True))

# Step 4: attach both producer stages at the outer axis (xo) of vexp_t.
# The IR printed after this step is the one the report is about.
s[vmuls_t].compute_at(s[vexp_t], xo)
s[vlog_t].compute_at(s[vexp_t], xo)

print(tvm.lower(s, [data, vexp_t], simple_mode=True))
print("-----------------")

The original IR:

// attr [vlog_t] storage_scope = "global"
allocate vlog_t[float16 * 31254528]
produce vlog_t {
  for (i0, 0, 1024) {
    for (i1, 0, 30522) {
      vlog_t[((i0*30522) + i1)] = log(data[((i0*30522) + i1)])
    }
  }
}
produce vmuls_t {
  for (i0, 0, 1024) {
    for (i1, 0, 30522) {
      vlog_t[((i0*30522) + i1)] = (vlog_t[((i0*30522) + i1)]*0.500000h)
    }
  }
}
produce vexp_t {
  for (i0, 0, 1024) {
    for (i1, 0, 30522) {
      vexp_t[((i0*30522) + i1)] = exp(vlog_t[((i0*30522) + i1)])
    }
  }
}

The IR after fuse:

// attr [vlog_t] storage_scope = "global"
allocate vlog_t[float16 * 31254528]
produce vlog_t {
  for (i0.i1.fused, 0, 31254528) {
    vlog_t[i0.i1.fused] = log(data[i0.i1.fused])
  }
}
produce vmuls_t {
  for (i0.i1.fused, 0, 31254528) {
    vlog_t[i0.i1.fused] = (vlog_t[i0.i1.fused]*0.500000h)
  }
}
produce vexp_t {
  for (i0.i1.fused, 0, 31254528) {
    vexp_t[i0.i1.fused] = exp(vlog_t[i0.i1.fused])
  }
}

The IR after fuse and split:

// attr [vlog_t] storage_scope = "global"
allocate vlog_t[float16 * 31254528]
produce vlog_t {
  for (i0.i1.fused.outer, 0, 15261) {
    for (i0.i1.fused.inner, 0, 2048) {
      vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = log(data[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)])
    }
  }
}
produce vmuls_t {
  for (i0.i1.fused.outer, 0, 15261) {
    for (i0.i1.fused.inner, 0, 2048) {
      vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = (vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)]*0.500000h)
    }
  }
}
produce vexp_t {
  for (i0.i1.fused.outer, 0, 15261) {
    for (i0.i1.fused.inner, 0, 2048) {
      vexp_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = exp(vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)])
    }
  }
}

The IR after compute_at (following fuse and split):

// attr [vlog_t] storage_scope = "global"
allocate vlog_t[float16 * 31254528]
produce vexp_t {
  for (i0.i1.fused.outer, 0, 15261) {
    produce vlog_t {
      for (i0.i1.fused.outer, 0, 15261) {
        for (i0.i1.fused.inner, 0, 2048) {
          vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = log(data[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)])
        }
      }
    }
    produce vmuls_t {
      for (i0.i1.fused.outer, 0, 15261) {
        for (i0.i1.fused.inner, 0, 2048) {
          vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = (vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)]*0.500000h)
        }
      }
    }
    for (i0.i1.fused.inner, 0, 2048) {
      vexp_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = exp(vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)])
    }
  }
}

We merged a fix for this a couple of weeks ago, so try a newer version and see if that’s better.
Here is the PR, if you’re interested: https://github.com/dmlc/tvm/pull/3073

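The generated IR that you posted is not wrong, just inefficient. If you were to run it, it would produce correct results, but with a lot of unnecessary computation: inside every iteration of the outer vexp_t loop, the full vlog_t and vmuls_t loop nests (15261 x 2048 elements) are recomputed, instead of only the 2048 elements that iteration actually consumes.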