Hi all, I am trying to reshape a convolution to make it suitable for a tiled GEMM. To do this, I fuse the width and height axes of the convolution and then split the fused axis by the tile size I want to compute. My actual code is quite convoluted, but the following snippet shows the main issue I am facing:
```python
import tvm
from tvm import te

m = 12
l = 16
A = te.placeholder((m, l), name='A')
A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = te.create_schedule(A2.op)

# Fuse the two axes of A1 and split the fused axis into tiles of 4
fused_axes_A1 = s[A1].fuse(A1.op.axis[0], A1.op.axis[1])
xo, xi = s[A1].split(fused_axes_A1, 4)

# Do the same for A2
fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
xo, xi = s[A2].split(fused_axes, 4)

s[A1].compute_at(s[A2], xo)  # comment on/off

print(tvm.lower(s, [A, A2], simple_mode=True))
```
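To make the indexing in the IR dumps easier to follow, here is a small pure-Python model (no TVM needed) of what `fuse` followed by `split(..., 4)` does to the iteration space: the fused index is `i * l + j`, and the split decomposes it as `outer * 4 + inner`. The helper `fused_split_order` is hypothetical, written only for illustration.

```python
def fused_split_order(m, l, factor):
    """Return the (i, j) points visited by a fuse + split loop nest, in order.

    Models:
        for outer in range(ceil(m * l / factor)):
            for inner in range(factor):
                fused = outer * factor + inner
                i, j = divmod(fused, l)
    """
    total = m * l
    n_outer = (total + factor - 1) // factor  # ceil division, so a partial last tile is covered
    order = []
    for outer in range(n_outer):
        for inner in range(factor):
            fused = outer * factor + inner
            if fused < total:  # tail guard, only needed when factor does not divide m * l
                order.append(divmod(fused, l))
    return order


order = fused_split_order(12, 16, 4)
print(len(order) // 4)  # 48 outer iterations, matching the IR
print(order[:5])        # [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4)]
```

Every (i, j) point is visited exactly once, and with `l = 16` each tile of 4 stays inside a single row.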
Now, without the `compute_at`, this is what the IR looks like:
```
for (i.j.fused.outer: int32, 0, 48) {
  for (i.j.fused.inner: int32, 0, 4) {
    A1[((i.j.fused.outer*4) + i.j.fused.inner)] = (float32*)A_2[((i.j.fused.outer*4) + i.j.fused.inner)]
  }
}
for (i.j.fused.outer_1: int32, 0, 48) {
  for (i.j.fused.inner_1: int32, 0, 4) {
    A2_2[((i.j.fused.outer_1*4) + i.j.fused.inner_1)] = ((float32*)A1[((i.j.fused.outer_1*4) + i.j.fused.inner_1)] + 3f32)
  }
}
```
This is the IR I would expect to be produced when I turn `compute_at` on:
```
for (i.j.fused.outer_1: int32, 0, 48) {
  for (i.j.fused.inner: int32, 0, 4) {
    A1[((i.j.fused.outer_1*4) + i.j.fused.inner)] = (float32*)A_2[((i.j.fused.outer_1*4) + i.j.fused.inner)]
  }
  for (i.j.fused.inner_1: int32, 0, 4) {
    A2_2[((i.j.fused.outer_1*4) + i.j.fused.inner_1)] = ((float32*)A1[((i.j.fused.outer_1*4) + i.j.fused.inner_1)] + 3f32)
  }
}
```
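As a sanity check that the expected loop nest is at least valid in principle, here is a plain-Python transcription of both loop nests on flat buffers (my own simulation, not TVM output). In each outer iteration, A2 only reads the four A1 values produced in that same iteration. Unwritten A1 entries are `None`, so reading one would raise a `TypeError`. The function names are hypothetical.

```python
def separate_stages(A, factor=4):
    """Model of the IR without compute_at: all of A1 first, then all of A2."""
    n = len(A)
    A1 = [None] * n
    for outer in range(n // factor):
        for inner in range(factor):
            A1[outer * factor + inner] = A[outer * factor + inner]
    A2 = [None] * n
    for outer in range(n // factor):
        for inner in range(factor):
            A2[outer * factor + inner] = A1[outer * factor + inner] + 3
    return A2


def expected_compute_at(A, factor=4):
    """Model of the expected IR: each outer iteration produces the A1 tile it consumes."""
    n = len(A)
    A1 = [None] * n  # None marks "not produced yet"; None + 3 would raise TypeError
    A2 = [None] * n
    for outer in range(n // factor):
        for inner in range(factor):  # produce the A1 tile
            A1[outer * factor + inner] = A[outer * factor + inner]
        for inner in range(factor):  # consume exactly that tile
            A2[outer * factor + inner] = A1[outer * factor + inner] + 3
    return A2


A = [float(v) for v in range(12 * 16)]
print(separate_stages(A) == expected_compute_at(A))  # True
```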
But this is what I get:
```
for (i.j.fused.outer, 0, 48) {
  for (i.j.fused.outer_1: int32, 0, floordiv(((((floordiv(((i.j.fused.outer*4) + 3), 16) + 1) - floordiv(i.j.fused.outer, 4))*((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))) + 3), 4)) {
    for (i.j.fused.inner: int32, 0, 4) {
      if @tir.likely((floordiv(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))) < ((floordiv(((i.j.fused.outer*4) + 3), 16) + 1) - floordiv(i.j.fused.outer, 4))), dtype=bool) {
        if @tir.likely((floormod(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))) < ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))), dtype=bool) {
          if @tir.likely((((i.j.fused.outer_1*4) + i.j.fused.inner) < (((floordiv(((i.j.fused.outer*4) + 3), 16) + 1) - floordiv(i.j.fused.outer, 4))*((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4)))), dtype=bool) {
            if @tir.likely((0 <= (floordiv(i.j.fused.outer, 4) + floordiv(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))))), dtype=bool) {
              if @tir.likely(((floordiv(i.j.fused.outer, 4) + floordiv(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4)))) < 12), dtype=bool) {
                if @tir.likely((0 <= ((floormod(i.j.fused.outer, 4)*4) + floormod(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))))), dtype=bool) {
                  if @tir.likely((((floormod(i.j.fused.outer, 4)*4) + floormod(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4)))) < 16), dtype=bool) {
                    A1[((i.j.fused.outer_1*4) + i.j.fused.inner)] = (float32*)A_2[(((floordiv(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4)))*16) + (i.j.fused.outer*4)) + floormod(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))))]
                  }
                }
              }
            }
          }
        }
      }
    }
  }
  for (i.j.fused.inner_1: int32, 0, 4) {
    A2_2[((i.j.fused.outer*4) + i.j.fused.inner_1)] = ((float32*)A1[((((floordiv(((i.j.fused.outer*4) + i.j.fused.inner_1), 16) - floordiv(i.j.fused.outer, 4))*((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))) + floormod(((i.j.fused.outer*4) + i.j.fused.inner_1), 16)) - (floormod(i.j.fused.outer, 4)*4))] + 3f32)
  }
}
```
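To see what the generated nest actually does for this shape, here is a rough pure-Python transcription of it (a hand simulation, not TVM output; `actual_ir_extents` and `simulate_actual_ir` are hypothetical helpers). As I read the IR, the first factor in the `outer_1` extent is the number of rows (i) of A1's region and the second is the number of columns (j). Python's `//` and `%` floor for positive divisors, like TIR's `floordiv`/`floormod`. Because 16 is divisible by 4, every tile lies in a single row, so the region is 1 x 4, the `i.j.fused.outer_1` loop runs once, and every `likely` guard holds. If my transcription is right, the result matches the expected IR. The simplifier just doesn't collapse the expressions.

```python
def actual_ir_extents(xo):
    """Evaluate the region extents from the IR above for one value of i.j.fused.outer."""
    rows = ((xo * 4 + 3) // 16 + 1) - xo // 4        # extent along i (bounded by 12)
    cols = ((xo * 4 + 3) % 16 + 1) - (xo % 4) * 4    # extent along j (bounded by 16)
    n_outer_1 = (rows * cols + 3) // 4               # extent of i.j.fused.outer_1
    return rows, cols, n_outer_1


def simulate_actual_ir(A):
    """Run the compute_at loop nest above on a flat list A of 12 * 16 values."""
    A2 = [None] * (12 * 16)
    for xo in range(48):
        rows, cols, n_outer_1 = actual_ir_extents(xo)
        A1 = [None] * (n_outer_1 * 4)  # A1 only holds the region needed by this tile
        for o1 in range(n_outer_1):
            for inner in range(4):
                k = o1 * 4 + inner
                i = xo // 4 + k // cols      # row in the original A1 domain
                j = (xo % 4) * 4 + k % cols  # column in the original A1 domain
                # The @tir.likely guards (k % cols < cols is always true, so it is omitted)
                if k // cols < rows and k < rows * cols and 0 <= i < 12 and 0 <= j < 16:
                    A1[k] = A[i * 16 + j]
        for inner_1 in range(4):
            f = xo * 4 + inner_1
            idx = (f // 16 - xo // 4) * cols + f % 16 - (xo % 4) * 4
            A2[f] = A1[idx] + 3  # raises TypeError if A1[idx] was never written
    return A2


print({actual_ir_extents(xo) for xo in range(48)})  # {(1, 4, 1)}
A = [float(v) for v in range(12 * 16)]
print(simulate_actual_ir(A) == [a + 3 for a in A])  # True
```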
If I don't fuse/split the original computation (`A1`), things go a bit more smoothly, but the output is still different from the expected one.
So I have two questions:
- (main one) Is there something wrong with the expected output? Is there any reason why, at least in theory, I should not expect it?
- If there is nothing wrong with the expected output, why does the compiler not produce it? Is there something that makes it complicated?
The example is taken (and slightly modified) from this PR: https://github.com/apache/incubator-tvm/pull/3073/ so I would also like to cc the people involved in the PR to get their insight.
Thanks, Giuseppe
cc @bas, @tqchen, @wweic, @anijain2305 , @ramana-arm
PS: It might be worth mentioning that in my case this simply fails: basically, this branch is taken and the computation either fails or takes a few minutes to complete (since I am doing a lot of redundant computation).
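One plausible source of the redundant computation is the following. This is a sketch that assumes the region of A1 needed by one tile is relaxed to a rectangular box in A1's original (i, j) space, which is how I read the `floordiv`/`floormod` bounds in the IR. When the split factor does not divide the row width, some tiles straddle a row boundary. For those tiles, the smallest enclosing box spans two rows and the full width. The helper `box_overhead` is hypothetical and only counts points; it does not model TVM itself.

```python
def box_overhead(m, l, factor):
    """Compare, summed over all tiles of the fused+split axis, the number of points
    in each tile with the size of the smallest (i, j) rectangle containing them.

    Returns (needed, boxed).
    """
    total = m * l
    needed = boxed = 0
    for start in range(0, total, factor):
        pts = [divmod(f, l) for f in range(start, min(start + factor, total))]
        rows = [i for i, _ in pts]
        cols = [j for _, j in pts]
        needed += len(pts)
        boxed += (max(rows) - min(rows) + 1) * (max(cols) - min(cols) + 1)
    return needed, boxed


print(box_overhead(12, 16, 4))  # (192, 192): every tile stays inside one row
print(box_overhead(12, 10, 4))  # (120, 216): tiles crossing a row boundary need a 2 x 10 box
```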