How to compute_at a stage to a not directly related stage

import tvm

dtype = "float32"
A = tvm.placeholder((64, 32), dtype=dtype, name="A")
C1 = tvm.compute((64, 32), lambda i, j: 7 * A[i, j], name="C1")
C2 = tvm.compute((64, 64), lambda i, j: 5 * A[i, j % 32], name="C2")
s = tvm.create_schedule([C1.op, C2.op])

C1 and C2 are not a producer/consumer pair, but they obviously share the same axis[0], which is derived from A.axis[0]. The following IR is what I expect; how can I generate it?

produce C2 {
  for (i0, 0, 64) {
    produce C1 {
      for (i1, 0, 32) {
        C1(i0, i1) = 7*A(i0, i1)
      }
    }
    for (i2, 0, 64) {
      C2(i0, i2) = 5*A(i0, i2 % 32)
    }
  }
}

and we would expect the Python interface to be

s[C1].compute_at(s[C2], s[C2].op.axis[0])

What are the key difficulties in supporting this?
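For reference, the pattern that works out of the box is attaching a stage under one of its actual consumers. Below is a minimal sketch of that workaround, written with the tvm.te API and using a hypothetical extra stage D (not part of the original question) that reads both C1 and C2, purely so the two stages have a common consumer:

import tvm
from tvm import te

A = te.placeholder((64, 32), name="A")
C1 = te.compute((64, 32), lambda i, j: 7 * A[i, j], name="C1")
C2 = te.compute((64, 64), lambda i, j: 5 * A[i, j % 32], name="C2")

# Hypothetical consumer that reads both C1 and C2; it only exists so that
# both stages become producers of D and compute_at is legal for them.
D = te.compute((64, 32), lambda i, j: C1[i, j] + C2[i, j], name="D")

s = te.create_schedule(D.op)
s[C1].compute_at(s[D], D.op.axis[0])
s[C2].compute_at(s[D], D.op.axis[0])
print(tvm.lower(s, [A, D], simple_mode=True))

This nests both C1 and C2 under D's outer loop, but it is not the IR asked for above: C1 is attached to D's loop rather than directly to C2's, and the extra stage changes the dataflow.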

I am facing a similar problem.

import tvm
from tvm import te

A = te.placeholder((64, 64), name="A")
B = te.placeholder((3, 3), name="B")

r0 = te.reduce_axis((0, 3), name="r0")
r1 = te.reduce_axis((0, 3), name="r1")
C1 = te.compute(
    (64, 64),
    lambda i, j: te.sum(A[i + r0, j + r1] * B[r0, r1], axis=[r0, r1]),
    name="C1",
)

# fresh reduce axes for C2 (these shadow the ones captured by C1)
r0 = te.reduce_axis((0, 3), name="r0")
r1 = te.reduce_axis((0, 3), name="r1")
C2 = te.compute(
    (64, 64),
    lambda i, j: te.sum(A[i + r0, j + r1], axis=[r0, r1]),
    name="C2",
)

C3 = te.compute(
    (64, 64), 
    lambda i, j: C1[i, j] * 2 + C2[i, j], 
    name="C3"
)

s = te.create_schedule(C3.op)
s[C1].compute_at(s[C3], C3.op.axis[1])
s[C2].compute_at(s[C3], C3.op.axis[1])
print(tvm.lower(s, [A, B, C3], simple_mode=True))

# s[C1].compute_at(s[C2], C2.op.reduce_axis[1])
# print(tvm.lower(s, [A, B, C3], simple_mode=True))   # fail

Out:

  for (i: int32, 0, 64) {
    for (j: int32, 0, 64) {
      C1_1: Buffer(C1, float32, [1], [], align=4)[0] = 0f32
      for (r0: int32, 0, 3) {
        for (r1: int32, 0, 3) {
          C1_1[0] = (C1_1[0] + (A_3: Buffer(A_2, float32, [4096], [])[((((i*64) + (r0*64)) + j) + r1)]*B_3: Buffer(B_2, float32, [9], [])[((r0*3) + r1)]))
        }
      }
      C2_1: Buffer(C2, float32, [1], [], align=4)[0] = 0f32
      for (r0_1: int32, 0, 3) {
        for (r1_1: int32, 0, 3) {
          C2_1[0] = (C2_1[0] + A_3[((((i*64) + (r0_1*64)) + j) + r1_1)])
        }
      }
      C3_3: Buffer(C3_2, float32, [4096], [])[((i*64) + j)] = ((C1_1[0]*2f32) + C2_1[0])
    }
  }

I want to further optimize the schedule by combining the two reduction loops, like:

  for (i: int32, 0, 64) {
    for (j: int32, 0, 64) {
      C1_1: Buffer(C1, float32, [1], [], align=4)[0] = 0f32
      C2_1: Buffer(C2, float32, [1], [], align=4)[0] = 0f32
      for (r0: int32, 0, 3) {
        for (r1: int32, 0, 3) {
          C1_1[0] = (C1_1[0] + (A_3: Buffer(A_2, float32, [4096], [])[((((i*64) + (r0*64)) + j) + r1)]*B_3: Buffer(B_2, float32, [9], [])[((r0*3) + r1)]))
          C2_1[0] = (C2_1[0] + A_3[((((i*64) + (r0*64)) + j) + r1)])
        }
      }
      C3_3: Buffer(C3_2, float32, [4096], [])[((i*64) + j)] = ((C1_1[0]*2f32) + C2_1[0])
    }
  }

How may I approach this?
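One possible direction, sketched below rather than a definitive answer: if the goal is literally to share a single r0/r1 loop nest between the two accumulators, te can express the two sums as a tuple reduction built with te.comm_reducer, so that C1 and C2 come out of one compute op and are lowered into one reduction loop. The reducer name two_sum and the op name C12 are illustrative, not from the original code:

import tvm
from tvm import te

A = te.placeholder((64, 64), name="A")
B = te.placeholder((3, 3), name="B")

# A commutative reducer that accumulates two independent sums at once.
def fcombine(x, y):
    return x[0] + y[0], x[1] + y[1]

def fidentity(t0, t1):
    return tvm.tir.const(0, t0), tvm.tir.const(0, t1)

two_sum = te.comm_reducer(fcombine, fidentity, name="two_sum")

r0 = te.reduce_axis((0, 3), name="r0")
r1 = te.reduce_axis((0, 3), name="r1")

# One compute op with two outputs; both accumulators share the same r0/r1 loops.
C1, C2 = te.compute(
    (64, 64),
    lambda i, j: two_sum(
        (A[i + r0, j + r1] * B[r0, r1], A[i + r0, j + r1]),
        axis=[r0, r1],
    ),
    name="C12",
)

C3 = te.compute((64, 64), lambda i, j: C1[i, j] * 2 + C2[i, j], name="C3")

s = te.create_schedule(C3.op)
s[C1].compute_at(s[C3], C3.op.axis[1])  # attaches the whole C12 op, i.e. both outputs
print(tvm.lower(s, [A, B, C3], simple_mode=True))

This does not make compute_at work across sibling stages; it side-steps the question by fusing the two reductions at the compute-definition level, which should produce IR close to the target above.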