Index arithmetic inside te.sum does not seem to work with tvm 0.8.dev0. In the example below, the `+ i - (i % 8)` term is missing from the index in the lowered code.
import tvm

n = 1024
dtype = "float32"
A = tvm.te.placeholder((n, n), dtype=dtype, name='A')
k = tvm.te.reduce_axis((0, 8), name='k')
B = tvm.te.compute((8,), lambda i: tvm.te.sum(A[i, k + i - (i % 8)], axis=k), name='B')
s = tvm.te.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))
Actual output:
primfn(A_1: handle, B_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [8], [])}
  buffer_map = {A_1: A, B_1: B} {
  for (i: int32, 0, 8) {
    B_2[i] = 0f32
    for (k: int32, 0, 8) {
      B_2[i] = ((float32*)B_2[i] + (float32*)A_2[((i*1024) + k)])
    }
  }
}
Expected output:
primfn(A_1: handle, B_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [8], [])}
  buffer_map = {A_1: A, B_1: B} {
  for (i: int32, 0, 8) {
    B_2[i] = 0f32
    for (k: int32, 0, 8) {
      B_2[i] = ((float32*)B_2[i] + (float32*)A_2[((i*1024) + k + i - (i % 8))])
    }
  }
}
Is this kind of tiling allowed, or should an error be generated here?
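One possible reading of the output, offered as an assumption rather than a confirmed description of what TVM does internally: with an output shape of (8,), i only takes the values 0..7, so i % 8 equals i and the term i - (i % 8) is always zero. The plain-Python sketch below (no TVM needed; the helper names written_index, lowered_index and mismatches are made up for this check) compares the flattened index written in the lambda with the one in the lowered loop over the same ranges.

```python
# Plain-Python check of the index arithmetic; this does not call TVM.
n = 1024  # row length of A, as in the example above


def written_index(i, k):
    # Flattened offset of A[i, k + i - (i % 8)] as written in the lambda.
    return i * n + (k + i - (i % 8))


def lowered_index(i, k):
    # Flattened offset that appears in the lowered code: A_2[(i*1024) + k].
    return i * n + k


def mismatches(i_extent, k_extent=8):
    # All (i, k) pairs for which the two offsets differ.
    return [
        (i, k)
        for i in range(i_extent)
        for k in range(k_extent)
        if written_index(i, k) != lowered_index(i, k)
    ]


print(mismatches(8))        # i in 0..7, as in B = compute((8,), ...)
print(len(mismatches(16)))  # i in 0..15, where i % 8 is no longer always i
```

If that reading is right, the two printed forms address the same elements for this shape, and the tiling term would only be expected to survive in the lowered code when the output extent is larger than 8, for example (n,).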