Fuse->split->compute_at issues

Hi all, I am trying to reshape a convolution to make it suitable for a tiled GEMM. To do this, I fuse the width and height axes of the convolution and then split the fused dimension by the tile size I want to compute. The code I am actually writing is quite convoluted, but this snippet shows the main issue I am facing:

import tvm
from tvm import te
m = 12
l = 16
A = te.placeholder((m, l), name='A')
A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = te.create_schedule(A2.op)
# Fuse and split the producer A1
fused_axes_A1 = s[A1].fuse(A1.op.axis[0], A1.op.axis[1])
xo_A1, xi_A1 = s[A1].split(fused_axes_A1, 4)
# Fuse and split the consumer A2
fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
xo, xi = s[A2].split(fused_axes, 4)
s[A1].compute_at(s[A2], xo)  # toggle this line on/off
print(tvm.lower(s, [A, A2], simple_mode=True))

Without compute_at, the IR looks like this:

for (i.j.fused.outer: int32, 0, 48) {
  for (i.j.fused.inner: int32, 0, 4) {
    A1[((i.j.fused.outer*4) + i.j.fused.inner)] = (float32*)A_2[((i.j.fused.outer*4) + i.j.fused.inner)]
  }
}
for (i.j.fused.outer_1: int32, 0, 48) {
  for (i.j.fused.inner_1: int32, 0, 4) {
    A2_2[((i.j.fused.outer_1*4) + i.j.fused.inner_1)] = ((float32*)A1[((i.j.fused.outer_1*4) + i.j.fused.inner_1)] + 3f32)
  }
}

This is the IR I would expect to be produced when I turn compute_at on:

for (i.j.fused.outer_1: int32, 0, 48) {
  for (i.j.fused.inner: int32, 0, 4) {
    A1[((i.j.fused.outer_1*4) + i.j.fused.inner)] = (float32*)A_2[((i.j.fused.outer_1*4) + i.j.fused.inner)]
  }
  for (i.j.fused.inner_1: int32, 0, 4) {
    A2_2[((i.j.fused.outer_1*4) + i.j.fused.inner_1)] = ((float32*)A1[((i.j.fused.outer_1*4) + i.j.fused.inner_1)] + 3f32)
  }
}
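As a sanity check that the expected loop nest is legal, here is a plain-Python emulation of it (no TVM involved; the helper name `expected_loop_nest` is hypothetical). For each outer tile of the fused (i, j) axis it first fills the four A1 elements, then computes the four A2 elements from them, and the result is compared with computing A + 3 directly.

```python
M, L, FACTOR = 12, 16, 4  # shapes and split factor from the snippet above


def expected_loop_nest(a_flat):
    """Emulate the expected IR on a row-major 1-D copy of A (like A_2 in the IR)."""
    a1 = [0.0] * (M * L)  # A1 buffer
    a2 = [0.0] * (M * L)  # A2 buffer
    for outer in range(M * L // FACTOR):  # i.j.fused.outer: 0..47
        for inner in range(FACTOR):       # producer A1, computed at this tile
            idx = outer * FACTOR + inner
            a1[idx] = a_flat[idx]
        for inner in range(FACTOR):       # consumer A2, same tile
            idx = outer * FACTOR + inner
            a2[idx] = a1[idx] + 3
    return a2


a_flat = [float(k) for k in range(M * L)]
result = expected_loop_nest(a_flat)
assert result == [x + 3 for x in a_flat]
```

Each A2 tile only reads the four A1 elements written earlier in the same outer iteration, so nothing in the dependence structure forbids this schedule.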

But this is what I get:

  for (i.j.fused.outer, 0, 48) {
    for (i.j.fused.outer_1: int32, 0, floordiv(((((floordiv(((i.j.fused.outer*4) + 3), 16) + 1) - floordiv(i.j.fused.outer, 4))*((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))) + 3), 4)) {
      for (i.j.fused.inner: int32, 0, 4) {
        if @tir.likely((floordiv(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))) < ((floordiv(((i.j.fused.outer*4) + 3), 16) + 1) - floordiv(i.j.fused.outer, 4))), dtype=bool) {
          if @tir.likely((floormod(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))) < ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))), dtype=bool) {
            if @tir.likely((((i.j.fused.outer_1*4) + i.j.fused.inner) < (((floordiv(((i.j.fused.outer*4) + 3), 16) + 1) - floordiv(i.j.fused.outer, 4))*((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4)))), dtype=bool) {
              if @tir.likely((0 <= (floordiv(i.j.fused.outer, 4) + floordiv(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))))), dtype=bool) {
                if @tir.likely(((floordiv(i.j.fused.outer, 4) + floordiv(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4)))) < 12), dtype=bool) {
                  if @tir.likely((0 <= ((floormod(i.j.fused.outer, 4)*4) + floormod(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))))), dtype=bool) {
                    if @tir.likely((((floormod(i.j.fused.outer, 4)*4) + floormod(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4)))) < 16), dtype=bool) {
                      A1[((i.j.fused.outer_1*4) + i.j.fused.inner)] = (float32*)A_2[(((floordiv(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4)))*16) + (i.j.fused.outer*4)) + floormod(((i.j.fused.outer_1*4) + i.j.fused.inner), ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))))]
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
    for (i.j.fused.inner_1: int32, 0, 4) {
      A2_2[((i.j.fused.outer*4) + i.j.fused.inner_1)] = ((float32*)A1[((((floordiv(((i.j.fused.outer*4) + i.j.fused.inner_1), 16) - floordiv(i.j.fused.outer, 4))*((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))) + floormod(((i.j.fused.outer*4) + i.j.fused.inner_1), 16)) - (floormod(i.j.fused.outer, 4)*4))] + 3f32)
    }
  }

If I don’t fuse/split the producer (A1) at all, things go a bit more smoothly, but the output still differs from the expected one.

So I have two questions:

  • (Main question) Is there something wrong with the expected output? Is there any reason why, at least in theory, I should not expect it?
  • If there is nothing wrong with the expected output, why does the compiler not produce it? Is there anything that makes it complicated?

The example is taken (and slightly modified) from this PR: https://github.com/apache/incubator-tvm/pull/3073/, so I am also cc’ing the people involved in the PR to get their insight.

Thanks, Giuseppe

cc @bas, @tqchen, @wweic, @anijain2305 , @ramana-arm

PS: It might be worth mentioning that in my real case this simply fails: basically this branch is taken, and the computation either fails or takes a few minutes to complete (since I am doing a lot of redundant computation).

How about doing compute_at first and then fusing i and j?

After some iterators are fused, TVM may lose the relation information between them. This seems like a bug, but I’m not sure it will be easy to track how the iterator relations change after fusion.
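To make “losing relation information” concrete, the sketch below models, in plain Python, how one consumer tile over the fused axis maps back to a rectangle on A1’s original (i, j) axes. This is a simplified model under an assumed conservative rule, not TVM’s actual bound.cc logic, and the helpers `a1_region_for_tile` and `region_size` are hypothetical: if a tile stays within one row, j keeps its exact range; if it wraps into the next row, j is relaxed to the whole row.

```python
def a1_region_for_tile(outer, l, factor):
    """Conservative (i, j) rectangle covering fused indices
    [outer*factor, outer*factor + factor - 1] of an (m, l) tensor."""
    lo = outer * factor
    hi = lo + factor - 1
    i_min, i_max = lo // l, hi // l
    if i_min == i_max:      # tile stays within one row: exact j range
        j_min, j_max = lo % l, hi % l
    else:                   # tile wraps rows: j must cover the whole row
        j_min, j_max = 0, l - 1
    return (i_min, i_max), (j_min, j_max)


def region_size(region):
    (i_min, i_max), (j_min, j_max) = region
    return (i_max - i_min + 1) * (j_max - j_min + 1)


# l = 16, factor = 4: every tile is 1 x 4, so the rectangle is exact.
size_16 = max(region_size(a1_region_for_tile(o, 16, 4)) for o in range(48))
# l = 10, factor = 4 (m = 6): some tiles wrap and the rectangle grows to 2 x 10.
size_10 = max(region_size(a1_region_for_tile(o, 10, 4)) for o in range(15))
print(size_16, size_10)  # 4 20
```

For l = 16 the i extent printed in the IR above, floordiv(((i.j.fused.outer*4) + 3), 16) + 1 - floordiv(i.j.fused.outer, 4), is always 1, so the region really is 1 x 4; the simplifier apparently just cannot prove it, which would explain the long floordiv/floormod expressions. When a tile wraps rows, a rectangle is a poor fit for a contiguous fused range, and the over-approximation is one plausible source of redundant computation.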

Hi @jcf94,

Doing compute_at first does not change things, unfortunately. However, if I don’t split/fuse the axis of A1.op at all, this is what I get:

  for (i.j.fused.outer, 0, 48) {
    for (i: int32, 0, ((floordiv(((i.j.fused.outer*4) + 3), 16) + 1) - floordiv(i.j.fused.outer, 4))) {
      if @tir.likely(((floordiv(i.j.fused.outer, 4) + i) < 12), dtype=bool) {
        for (j: int32, 0, ((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))) {
          if @tir.likely((((floormod(i.j.fused.outer, 4)*4) + j) < 16), dtype=bool) {
            A1[((i*((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))) + j)] = (float32*)A_2[(((i*16) + (i.j.fused.outer*4)) + j)]
          }
        }
      }
    }
    for (i.j.fused.inner: int32, 0, 4) {
      A2_2[((i.j.fused.outer*4) + i.j.fused.inner)] = ((float32*)A1[((((floordiv(((i.j.fused.outer*4) + i.j.fused.inner), 16) - floordiv(i.j.fused.outer, 4))*((floormod(((i.j.fused.outer*4) + 3), 16) + 1) - (floormod(i.j.fused.outer, 4)*4))) + floormod(((i.j.fused.outer*4) + i.j.fused.inner), 16)) - (floormod(i.j.fused.outer, 4)*4))] + 3f32)
    }
  }

As I mentioned before, this is a bit better, but still not the ideal output. I am digging into bound.cc to understand more, but if you have any advice, please let me know!

Thanks,

Giuseppe

Er… my bad, let me explain more clearly.

Try this sequence of operations:

import tvm
from tvm import te
m = 12
l = 16
A = te.placeholder((m, l), name='A')
A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = te.create_schedule(A2.op)

# Split first
i, j = A1.op.axis
jo, ji = s[A1].split(j, 4)
i, j = A2.op.axis
jo, ji = s[A2].split(j, 4)

# Compute at to get the right loop structure
s[A1].compute_at(s[A2], jo)

# Fuse the outer loop
s[A2].fuse(i, jo)

print(tvm.lower(s, [A, A2]))

This way, I get:

primfn(A_1: handle, A2_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {A2: Buffer(A2_2: handle, float32, [12, 16], []),
             A: Buffer(A_2: handle, float32, [12, 16], [])}
  buffer_map = {A_1: A, A2_1: A2} {
  attr [A1: handle] "storage_scope" = "global";
  allocate(A1, float32, [4]);
  for (i.j.outer.fused: int32, 0, 48) {
    for (j.inner: int32, 0, 4) {
      A1[j.inner] = (float32*)A_2[((i.j.outer.fused*4) + j.inner)]
    }
    for (j.inner_1: int32, 0, 4) {
      A2_2[((i.j.outer.fused*4) + j.inner_1)] = ((float32*)A1[j.inner_1] + 3f32)
    }
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data

I guess this is what you want. :smiley:

Hi @jcf94, first of all, thank you so much for your help!

We are getting close to what I want :slight_smile:

Unfortunately, I don’t think fuse and split are interchangeable: if l is not a multiple of the split factor, @tir.likely conditions appear.

For instance, with m=6 and l=10, fusing gives 60 iterations, and splitting by 4 is exact. The other way around, when I split l, the axis is “extended” to 12 and @tir.likely conditions are inserted, which are then propagated through compute_at and fuse.
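The arithmetic behind this can be checked directly. This plain-Python sketch counts the iteration points for m = 6, l = 10 and a split factor of 4 in both orders; the extra points in the split-first order are the ones that need a @tir.likely guard.

```python
import math

m, l, factor = 6, 10, 4

# Fuse first, then split: 60 fused points divide evenly into 15 tiles of 4.
fused = m * l
tiles = math.ceil(fused / factor)
fuse_first_padding = tiles * factor - fused
print(fused, tiles, fuse_first_padding)  # 60 15 0

# Split l first: each row is padded to ceil(10 / 4) * 4 = 12 points.
padded_row = math.ceil(l / factor) * factor
iterations = m * padded_row
split_first_padding = iterations - fused
print(padded_row, iterations, split_first_padding)  # 12 72 12
```

Fusing first keeps the padding at zero whenever m * l is a multiple of the factor, while splitting l first pads every row independently.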

I am not sure there is a solution to that, because the information about the original l is somehow lost when I split.

Anyway, this answers my main question: what I expect in theory is the correct result; now we only have to find a way to make it happen.

OK, apparently this is a known limitation: https://tvm.apache.org/docs/dev/inferbound.html#limitations-of-passupdomain