Hi, I’m trying to reuse the input reads because they’re expensive for my use case. Here’s a simple example:
import tvm

in_size = 10
filter_size = 3
out_size = in_size - filter_size + 1

A = tvm.placeholder((in_size,), name='Input')
ra = tvm.reduce_axis((0, filter_size), name='ra')
Out = tvm.compute((out_size,), lambda x: tvm.sum(A[x + ra], axis=[ra]), name='Out')

s = tvm.create_schedule(Out.op)
AL = s.cache_read(A, "local", [Out])
s[AL].compute_at(s[Out], Out.op.axis[0])

print(tvm.lower(s, [A, Out], simple_mode=True))
This produces:
// attr [Input.local] storage_scope = "local"
allocate Input.local[float32 * 3]
produce Out {
  for (x, 0, 8) {
    produce Input.local {
      for (ax0, 0, 3) {
        Input.local[ax0] = Input[(x + ax0)] /* here there are 3 loads for each x iteration */
      }
    }
    Out[x] = 0.000000f
    for (ra, 0, 3) {
      Out[x] = (Out[x] + Input.local[ra])
    }
  }
}
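For reference, the computation itself is just a sliding-window sum. A plain-Python sketch of its semantics (not TVM code, names are my own) looks like this:

```python
# Sliding-window sum: Out[x] = A[x] + A[x+1] + ... + A[x+filter_size-1].
in_size = 10
filter_size = 3
out_size = in_size - filter_size + 1  # 8

A = list(range(in_size))  # stand-in input data

out = [sum(A[x + r] for r in range(filter_size)) for x in range(out_size)]
print(out)
```

In the naive schedule above, each element of A is loaded filter_size times (once per overlapping window), which is what I want to avoid.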
But I need something like:
// attr [Input.local] storage_scope = "local"
allocate Input.local[float32 * 3]
produce Out {
  Input.local[0] = Input[0]
  Input.local[1] = Input[1]
  for (x, 0, 8) {
    Input.local[2] = Input[(x + 2)] /* only one load inside the x loop */
    Out[x] = 0.000000f
    for (ra, 0, 3) {
      Out[x] = (Out[x] + Input.local[ra])
    }
    Input.local[0] = Input.local[1]
    Input.local[1] = Input.local[2] /* some way to permute the loaded local values or their pointers */
  }
}
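To make the intent concrete, here is a plain-Python sketch (not TVM) of the shift-register pattern the desired IR implements: preload filter_size - 1 elements, then perform exactly one load per x iteration and rotate the window afterwards, so every input element is read exactly once.

```python
# Shift-register sliding-window sum: one input load per output element.
in_size = 10
filter_size = 3
out_size = in_size - filter_size + 1

A = list(range(in_size))
local = [0] * filter_size
loads = 0

# Prologue: fill all but the last window slot.
for i in range(filter_size - 1):
    local[i] = A[i]
    loads += 1

out = []
for x in range(out_size):
    local[filter_size - 1] = A[x + filter_size - 1]  # the only load in the loop
    loads += 1
    out.append(sum(local))
    # Rotate: shift the window left by one slot.
    for i in range(filter_size - 1):
        local[i] = local[i + 1]

print(out, loads)  # loads == in_size: each input element read once
```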
I tried double buffering, but s[AL].double_buffer() grows the local buffer to Input.local[float32 * 2 * 3], so that's no good. Combining s[AL].compute_at(s[Out], Out.op.reduce_axis[0]) with s[AL].double_buffer() instead yields Input.local[float32 * 2 * 1], which doesn't help either.
Is there a way to generate the above schedule using the current TVM functions?