About compute_at()

Hi, I just started learning TVM and have a small question.

import tvm
from tvm import te

m = 10  # loop extent; the printed IR below uses 10
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")
D = te.compute((m,), lambda i: C[i] + 3, name="D")

s = te.create_schedule([D.op])
s[C].compute_at(s[D], D.op.axis[0])  # compute C inside D's loop
s[B].compute_at(s[C], C.op.axis[0])  # compute B inside C's loop
print(tvm.lower(s, [A, B, C, D], simple_mode=True))

When I run the above, all of the loops are merged into a single loop, as I want:

  for (i: int32, 0, 10) {
    B[i] = (A[i] + 1f32)
    C[i] = (B[i]*2f32)
    D[i] = (C[i] + 3f32)
  }

But if I change D = C + 3 to D = C + B, like this:

import tvm
from tvm import te

m = 10  # loop extent; the printed IR uses 10
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")
D = te.compute((m,), lambda i: C[i] + B[i], name="D")  # D now reads B directly

s = te.create_schedule([D.op])
s[C].compute_at(s[D], D.op.axis[0])  # compute C inside D's loop
s[B].compute_at(s[C], C.op.axis[0])  # compute B inside C's loop
print(tvm.lower(s, [A, B, C, D], simple_mode=True))

then lowering fails with this error:

Check failed: (found_attach || stage_attach.size() == 0) is false: Invalid Schedule, cannot find the producer compute(B, body=[(A[i] + 1f)], axis=[iter_var(i, range(min=0, ext=10))], reduce_axis=[], tag=, attrs={}) along the loop nest specified by compute_at of consumer compute(D, body=[(C[i] + B[i])], axis=[iter_var(i, range(min=0, ext=10))], reduce_axis=[], tag=, attrs={})

What I actually want is something like this:

  for (i: int32, 0, 10) {
    B[i] = (A[i] + 1f32)
    C[i] = (B[i]*2f32)
    D[i] = (C[i] + B[i])
  }

Is this possible? If so, how can I achieve it?

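Try attaching B to D's loop instead of C's. Since B is read by both C and D, it has to be computed at a loop that encloses both consumers, and C's loop only encloses C. Replace the last compute_at call with: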

s[B].compute_at(s[D], D.op.axis[0])

Why didn’t I think of that? Thanks, it works!