Cannot fuse two dependent stages with compute_at in softmax schedule

Hello,

In python/tvm/topi/x86/nn.py I’m trying to modify the schedule of softmax. Looking at the lowered IR of the current schedule, I think one of the for-loops can be eliminated for better performance:

orig s= @main = primfn(A_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [262144], [])}
  buffer_map = {A_1: A}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [256, 1024], [])} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [256]), storage_scope = global;
  allocate(T_softmax_exp: Pointer(global float32), float32, [262144]), storage_scope = global {
    for (i0: int32, 0, 256) {
      T_softmax_maxelem_1: Buffer(T_softmax_maxelem, float32, [256], [])[i0] = -3.40282e+38f32
      for (k: int32, 0, 1024) {
        T_softmax_maxelem_1[i0] = max(T_softmax_maxelem_1[i0], A[((i0*1024) + k)])
      }
    }
    for (i0_1: int32, 0, 256) {
      for (i1: int32, 0, 1024) {
        let cse_var_1: int32 = ((i0_1*1024) + i1)
        T_softmax_exp_1: Buffer(T_softmax_exp, float32, [262144], [])[cse_var_1] = @tir.exp((A[cse_var_1] - T_softmax_maxelem_1[i0_1]), dtype=float32)
      }
    }
    for (i0_2: int32, 0, 256) {
      T_softmax_maxelem_2: Buffer(T_softmax_maxelem, float32, [256], [])[i0_2] = 0f32
      for (k_1: int32, 0, 1024) {
        T_softmax_maxelem_2[i0_2] = (T_softmax_maxelem_2[i0_2] + T_softmax_exp_1[((i0_2*1024) + k_1)])
      }
    }
    for (i0_3: int32, 0, 256) {
      for (i1_1: int32, 0, 1024) {
        let cse_var_2: int32 = ((i0_3*1024) + i1_1)
        T_softmax_exp_2: Buffer(T_softmax_exp, float32, [262144], [])[cse_var_2] = (T_softmax_exp_1[cse_var_2] / T_softmax_maxelem_2[i0_3])
      }
    }
  }
}
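
For reference, the IR above can be reproduced with a small standalone script (a sketch: the shape [256, 1024] is taken from the buffer in the dump, and I build the generic topi softmax with a default schedule):

import tvm
from tvm import te, topi

# Shape taken from the buffer in the dump above: float32 [256, 1024].
A = te.placeholder((256, 1024), name="A", dtype="float32")
B = topi.nn.softmax(A, axis=1)

# Default (unmodified) schedule: lowers to the four separate row loops shown above.
s = te.create_schedule(B.op)
print("orig s=", tvm.lower(s, [A], simple_mode=True))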

I think it should be possible to fuse the two loops so that each row of T_softmax_exp_1 is summed immediately after it is computed. Since there is a dependency between those two stages, I added a compute_at(), and that is where the error occurs.
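
For context, this is how I get hold of the individual stages (following the extraction pattern already used in schedule_softmax; the variable names match the debug prints below):

softmax_op = outs[0].op                  # T_softmax_norm, tag == "softmax_output"
exp = softmax_op.input_tensors[0]        # T_softmax_exp
expsum = softmax_op.input_tensors[1]     # T_softmax_expsum
max_elem = s[exp].op.input_tensors[1]    # T_softmax_maxelem

The compute_at call I added: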

s[exp].compute_at(s[expsum], s[expsum].op.axis[0])

gives:

Check failed: (found_attach || stage_attach.size() == 0) is false: Invalid Schedule, cannot find the producer compute(T_softmax_exp, body=[tir.exp((A[i0, i1] - T_softmax_maxelem[i0]))], axis=[iter_var(i0, range(min=0, ext=256)), iter_var(i1, range(min=0, ext=1024))], reduce_axis=[], tag=softmax_output, attrs={}) along the loop nest specified by compute_at of consumer compute(T_softmax_norm, body=[(T_softmax_exp[i0, i1]/T_softmax_expsum[i0])], axis=[iter_var(i0, range(min=0, ext=256)), iter_var(i1, range(min=0, ext=1024))], reduce_axis=[], tag=softmax_output, attrs={"axis": 1})

debug prints:

s[expsum]= stage(T_softmax_expsum, compute(T_softmax_expsum, body=[reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f]), source=[T_softmax_exp[i0, k]], init=[], axis=[iter_var(k, range(min=0, ext=1024))], where=(bool)1, value_index=0)], axis=[iter_var(i0, range(min=0, ext=256))], reduce_axis=[iter_var(k, range(min=0, ext=1024))], tag=softmax_output, attrs={}))
s[exp]= stage(T_softmax_exp, compute(T_softmax_exp, body=[tir.exp((A[i0, i1] - T_softmax_maxelem[i0]))], axis=[iter_var(i0, range(min=0, ext=256)), iter_var(i1, range(min=0, ext=1024))], reduce_axis=[], tag=softmax_output, attrs={}))
s[softmax_op]= stage(T_softmax_norm, compute(T_softmax_norm, body=[(T_softmax_exp[i0, i1]/T_softmax_expsum[i0])], axis=[iter_var(i0, range(min=0, ext=256)), iter_var(i1, range(min=0, ext=1024))], reduce_axis=[], tag=softmax_output, attrs={"axis": 1}))
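
If I read the error correctly, the problem is that T_softmax_exp has two consumers: T_softmax_expsum and T_softmax_norm. Attaching it inside s[expsum]’s loop nest puts it out of scope for T_softmax_norm, which still reads T_softmax_exp[i0, i1] in the final division, so the scheduler cannot find a valid attach point for that consumer. Is the right way to express this fusion to attach all producers at the row loop of the final stage instead, e.g. (untested sketch):

s[max_elem].compute_at(s[softmax_op], s[softmax_op].op.axis[0])
s[exp].compute_at(s[softmax_op], s[softmax_op].op.axis[0])
s[expsum].compute_at(s[softmax_op], s[softmax_op].op.axis[0])

so that the max, exp, and sum for each row are all computed inside a single outer loop over i0? Or is there a way to keep the attachment at s[expsum]?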