[Tensorize] Support conditions inside tensorize scope

Hi all,

We are currently working on tensorization for a set of abstract intrinsics. Once the code is tensorized, a later pass can select the actual HW intrinsics.

So, in the tensorize step, we do not care about any tail loop iterations (uneven split cases).
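To make the uneven split concrete, here is a minimal plain-Python sketch of the loop structure. The helper name split_iterations is hypothetical, and this only models the iteration space, not TVM's actual lowering; 129 and 16 are the sizes used in the test case in this post.

```python
def split_iterations(m, factor):
    """Yield (xo, xi, in_bounds) for a loop of extent m split by factor."""
    outer_extent = (m + factor - 1) // factor  # ceiling division
    for xo in range(outer_extent):
        for xi in range(factor):
            # The guard that protects the tail iterations.
            yield xo, xi, xo * factor + xi < m


iterations = list(split_iterations(129, 16))
outer_extent = iterations[-1][0] + 1
guarded_out = sum(1 for _, _, in_bounds in iterations if not in_bounds)
print(outer_extent, len(iterations), guarded_out)  # 9 144 15
```

Only the last outer iteration contains guarded-out elements, which is the tail we would like a later pass or the HW to deal with.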

Here is a test case for this scenario.

import tvm
from tvm import te

def intrin_vadd(xo, m, n):
    x = te.placeholder((n,), name="vx")
    y = te.placeholder((n,), name="vy")
    if m % n == 0:
        body = lambda i: x[i] + y[i]
    else:
        body = lambda i: tvm.tir.Select(xo * n + i < m, x[i] + y[i], tvm.tir.const(0, dtype=x.dtype))
    z = te.compute(x.shape, body, name="z")

    def intrin_func(ins, outs):
        xx, yy = ins
        zz = outs[0]
        return tvm.tir.call_packed("vadd", xx, yy, zz)

    buffer_params = {"offset_factor": 16}
    return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params=buffer_params)

def add(m):
    x = te.placeholder((m,), name="x")
    y = te.placeholder((m,), name="y")
    z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
    return x, y, z

def check_cache_write(m, factor):
    x, y, z = add(m)
    s = te.create_schedule(z.op)
    _, _ = s[z].split(z.op.axis[0], factor=factor)

    z_global = s.cache_write(z, "global")
    xo, xi = z_global.op.axis

    cond = xo * factor + xi < m 
    vadd = intrin_vadd(xo, m, factor)
    s[z_global].tensorize(xi, vadd)
    tvm.lower(s, [x, y, z]) 

check_cache_write(129, 16)

After splitting the axis of z, a condition like xo * factor + xi < m is generated to guard the computation. I described the intrinsic with the same IR, but tensorize failed with a mismatch error.

  File "~/apache/tvm/src/te/operation/tensorize.cc", line 336
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
  Check failed: expr_equal(lhs, rhs) == false: Failed to match the compute with TensorIntrin tensor_intrin's declaration  provided= select((((i.outer.c*16) + i) < 129), (vx[i] + vy[i]), 0f), intrin=  select((((i.outer.c*16) + i) < 129), (vx[i] + vy[i]), 0f)

The two compute bodies print identically, but the i.outer.c variables in them are different objects (their addresses are not equal). After some investigation, I found that during the schedule.normalize() pass the outer loop var xo is rebased and replaced by a new IterVarNode with a new address.
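The effect can be reproduced with a toy model in plain Python. This is only a sketch of the idea, not TVM's comparison code: the Var class and the expr_equal helper below are hypothetical stand-ins, with expressions encoded as nested tuples and variables compared by object identity.

```python
class Var:
    """Toy loop variable; two Vars are equal only if they are the same object."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


def expr_equal(lhs, rhs):
    """Compare nested tuples structurally; Var leaves compare by identity."""
    if isinstance(lhs, Var) or isinstance(rhs, Var):
        return lhs is rhs
    if isinstance(lhs, tuple) and isinstance(rhs, tuple):
        return len(lhs) == len(rhs) and all(
            expr_equal(a, b) for a, b in zip(lhs, rhs)
        )
    return lhs == rhs


xo_declared = Var("i.outer.c")  # the var captured by the intrinsic declaration
xo_rebased = Var("i.outer.c")   # a fresh var with the same name after rebasing
i = Var("i")

intrin_cond = ("<", ("+", ("*", xo_declared, 16), i), 129)
provided_cond = ("<", ("+", ("*", xo_rebased, 16), i), 129)

same_text = repr(intrin_cond) == repr(provided_cond)
same_expr = expr_equal(intrin_cond, provided_cond)
print(same_text, same_expr)  # True False
```

The printed forms match, just like in the error message, while the comparison fails on the one leaf whose identity changed.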

ref: tvm/schedule_dataflow_rewrite.cc at main · apache/tvm (github.com)

The check then fails because the address of i.outer.c in the intrinsic declaration no longer matches the one in the compute body.

By contrast, the compute op does run the substitution.

ref: tvm/compute_op.cc at main · apache/tvm (github.com) and tvm/compute_op.cc at main · apache/tvm (github.com)

Finally, I tried to fix this issue by changing this line tvm/tensorize.cc at main · apache/tvm (github.com) to

PrimExpr rhs = ana.Simplify(Substitute(intrin_compute->body[i], value_map));

and it works.
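As a rough illustration of why substituting first helps, here is a plain-Python sketch. It is not the tensorize.cc code: the Var class and substitute helper are hypothetical, expressions are nested tuples, and the contents of value_map are an assumption made for this example (a mapping from the declared outer var to the rebased one).

```python
class Var:
    """Toy loop variable; default equality and hashing are by object identity."""

    def __init__(self, name):
        self.name = name


def substitute(expr, value_map):
    """Return a copy of expr with every Var found in value_map replaced."""
    if isinstance(expr, Var):
        return value_map.get(expr, expr)
    if isinstance(expr, tuple):
        return tuple(substitute(e, value_map) for e in expr)
    return expr


xo_declared = Var("i.outer.c")  # var used when declaring the intrinsic
xo_rebased = Var("i.outer.c")   # var seen in the compute body after rebasing
i = Var("i")

intrin_body = ("<", ("+", ("*", xo_declared, 16), i), 129)
provided_body = ("<", ("+", ("*", xo_rebased, 16), i), 129)

# Tuple equality falls back to Var identity at the leaves.
matches_before = intrin_body == provided_body

value_map = {xo_declared: xo_rebased}  # assumed mapping, for illustration only
matches_after = substitute(intrin_body, value_map) == provided_body
print(matches_before, matches_after)  # False True
```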

There is a similar issue as below.

All in all, this issue raises some questions.

  1. Does Tensorize have this constraint deliberately?
  2. Can we use Tensorize to cover the tail loop cases? We will process the tail part on the HW side or in another pass.

The fix: [Tensorize] Support conds depend on outer loop vars inside tensorize scope by leeexyz · Pull Request #7497 · apache/tvm (github.com)