How to add a split condition inside a tensorize scope

Hi all,

I need to split an axis whose size is a large prime number, and then tensorize at the inner iter xi. However, something goes wrong and I have no idea why:
    Traceback (most recent call last):
      File "abs.py", line 38, in <module>
        tf_abs((131,), "float16")
      File "abs.py", line 34, in tf_abs
        print(tvm.lower(sch, [in_data, res], simple_mode=True))
      File "/home/laughing/soft/tvm/python/tvm/build_module.py", line 341, in lower
        stmt = schedule.ScheduleOps(sch, bounds)
      File "/home/laughing/soft/tvm/python/tvm/_ffi/function.py", line 280, in my_api_func
        return flocal(*args)
      File "/home/laughing/soft/tvm/python/tvm/_ffi/_ctypes/function.py", line 184, in __call__
        ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
      File "/home/laughing/soft/tvm/python/tvm/_ffi/base.py", line 66, in check_call
        raise TVMError(py_str(_LIB.TVMGetLastError()))
    tvm._ffi.base.TVMError: [10:45:36] /home/laughing/soft/tvm/src/op/tensorize.cc:127: Tensorize failed, split condition likely(((i0.inner + (i0.outer*128)) < 131)) relies on var defined inside tensorize scope
Thanks,

    import tvm

    def intrin_abs(n, dtype):
        x = tvm.placeholder((n,), dtype=dtype, name="x")
        y = tvm.compute((n,), lambda *indices: tvm.select(x(*indices) < 0, -x(*indices), x(*indices)))

        def intrin_func(ins, outs):
            xx = ins[0]
            yy = outs[0]
            # replace the whole inner loop with one packed "vabs" call
            return tvm.call_packed("vabs", xx, yy)

        with tvm.build_config(offset_factor=128):
            return tvm.decl_tensor_intrin(y.op, intrin_func)

    def tf_abs(shape, dtype):
        in_data = tvm.placeholder(shape, dtype=dtype, name="in_data")
        res = tvm.compute(shape, lambda *indices: tvm.select(in_data(*indices) < 0, -in_data(*indices), in_data(*indices)))
        sch = tvm.create_schedule(res.op)
        # 131 is prime, so factor=128 leaves an uneven tail
        xo, xi = sch[res].split(sch[res].op.axis[0], factor=128)
        sch[res].tensorize(xi, intrin_abs(128, in_data.dtype))
        print(tvm.lower(sch, [in_data, res], simple_mode=True))

    if __name__ == "__main__":
        tf_abs((131,), "float16")


Can you check whether this works when the split is even? My guess is that this is due to the branch generated to cover the tail loop iterations (when only a fraction of a tile is executed).

Unfortunately I do not know of any good solution for supporting uneven splits without branches here. Can you afford the penalty of padding your input to ensure an even split?
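
A quick way to check the even-split case with the code from the first post (assuming nothing else changes) is to call tf_abs with a size that is a multiple of the factor:

    # even split: 1024 = 8 * 128, so no tail guard should be generated
    tf_abs((1024,), "float16")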

@eqy Thanks for the reply.
I tried the following:

  1. input shape = (919,), a prime number
  2. pad the input to shape (1024,)
  3. compute res_temp (abs)
  4. compute res

The code is shown below:

    # compute
    in_data = tvm.placeholder((919,), name="in_data")
    pad_data = tvm.compute((1024,), 
                           lambda i: tvm.select(i < 919, in_data[i], tvm.const(0, dtype=in_data.dtype)),
                           name="pad_data")
    res_temp = tvm.compute((1024,),
                           lambda i: tvm.select(pad_data[i] >= 0, pad_data[i], -pad_data[i]),
                           name="res_temp")
    res = tvm.compute((919,),
                      lambda i: res_temp[i],
                      name="res")

    # schedule
    sch = tvm.create_schedule(res.op)
    # print(tvm.lower(sch, [in_data, res], simple_mode=True))

    # cache_read / cache_write into the on-chip buffer
    # (param.scope_ubuf is the buffer scope in my environment, e.g. "local.UB")
    in_data_buf = sch.cache_read(in_data, param.scope_ubuf, [pad_data])
    pad_data_buf = sch.cache_write(pad_data, param.scope_ubuf)
    res_temp_buf = sch.cache_write(res_temp, param.scope_ubuf)

    # compute_inline()
    sch[pad_data].compute_inline()
    sch[res_temp].compute_inline()
    # print(tvm.lower(sch, [in_data, res], simple_mode=True))

    # split, factor=128
    factor = 128
    in_data_buf_xo, in_data_buf_xi = sch[in_data_buf].split(sch[in_data_buf].op.axis[0], factor=factor)
    pad_data_buf_xo, pad_data_buf_xi = sch[pad_data_buf].split(sch[pad_data_buf].op.axis[0], factor=factor)
    res_temp_buf_xo, res_temp_buf_xi = sch[res_temp_buf].split(sch[res_temp_buf].op.axis[0], factor=factor)
    res_xo, res_xi = sch[res].split(sch[res].op.axis[0], factor=factor)

    # compute_at
    sch[res_temp_buf].compute_at(sch[res], res_xo)
    sch[pad_data_buf].compute_at(sch[res], res_xo)
    sch[in_data_buf].compute_at(sch[res], res_xo)
    print(tvm.lower(sch, [in_data, res], simple_mode=True))

The output is:

    // attr [in_data.local.UB] storage_scope = "local.UB"
    allocate in_data.local.UB[float32 * 128]
    custom_new { 0 }
    custom_delete { nop(<args>); }
    produce res {
      for (i.outer, 0, 8) {
        produce in_data.local.UB {
          for (ax0.inner, 0, 128) {
            if (likely(((i.outer*128) < (919 - ax0.inner)))) {
              in_data.local.UB[ax0.inner] = in_data[((i.outer*128) + ax0.inner)]
            }
          }
        }
        produce pad_data.local.UB {
          for (i.c.inner, 0, 128) {
            in_data.local.UB[i.c.inner] = tvm_if_then_else(((i.outer*128) < (919 - i.c.inner)), in_data.local.UB[i.c.inner], 0.000000f)
          }
        }
        produce res_temp.local.UB {
          for (i.c.inner, 0, 128) {
            in_data.local.UB[i.c.inner] = tvm_if_then_else((in_data.local.UB[i.c.inner] < 0.000000f), (in_data.local.UB[i.c.inner]*-1.000000f), in_data.local.UB[i.c.inner])
          }
        }
        for (i.inner, 0, 128) {
          if (likely(((i.outer*128) < (919 - i.inner)))) {
            res[((i.outer*128) + i.inner)] = in_data.local.UB[i.inner]
          }
        }
      }
    }

Then I tensorize pad_data.local.UB:

    # tensorize
    def intrin_pad(n, iterval):
        shape = (n,)
        in_data = tvm.placeholder(shape, name="data")
        # the tail condition closes over the outer iter var passed in at declaration time
        res = tvm.compute(shape, lambda i: tvm.select(i < (919 - (iterval*128)), in_data[i], tvm.const(0, dtype=in_data.dtype)), name="res")
        def intrin_func(ins, outs):
            xx = ins[0]
            yy = outs[0]
            return tvm.call_packed("vpad", xx, yy)
        with tvm.build_config(offset_factor=128):
            return tvm.decl_tensor_intrin(res.op, intrin_func)

    sch[pad_data_buf].tensorize(pad_data_buf_xi, intrin_pad(128, pad_data_buf_xo))
    print(tvm.lower(sch, [in_data, res], simple_mode=True))

Error message:

    Traceback (most recent call last):
      File "test.py", line 67, in <module>
        print(tvm.lower(sch, [in_data, res], simple_mode=True))
      File "/home/w00421387/repote/tensor_engine/python/te/tvm/build_module.py", line 341, in lower
        stmt = schedule.ScheduleOps(sch, bounds)
      File "/home/w00421387/repote/tensor_engine/python/te/tvm/_ffi/function.py", line 280, in my_api_func
        return flocal(*args)
      File "/home/w00421387/repote/tensor_engine/python/te/tvm/_ffi/_ctypes/function.py", line 183, in __call__
        ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
      File "/home/w00421387/repote/tensor_engine/python/te/tvm/_ffi/base.py", line 66, in check_call
        raise TVMError(py_str(_LIB.TVMGetLastError()))
    tvm._ffi.base.TVMError: [11:21:05] /home/w00421387/repote/tensor_engine/src/op/tensorize.cc:357: Check failed: Equal(lhs, rhs) Failed to match the compute with TensorIntrin tensor_intrin's declaration provided=select((i < (919 - (i.outer*128))), data(i), 0.000000f), intrin=select((i < (919 - (i.c.outer*128))), data(i), 0.000000f)

Note that the provided and declared expressions differ only in the outer loop variable (i.outer vs. i.c.outer): the intrin declaration captured pad_data_buf_xo, which is named i.c.outer after cache_write, so the structural Equal() check fails even though the two values coincide at runtime.
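
If padding is acceptable, one possible way around both errors is to keep the data-dependent selects in plain (non-tensorized) stages and tensorize only the abs on the padded buffer, whose 1024/128 split is even. Below is a minimal, untested sketch along those lines, reusing intrin_abs from the first post (the "vabs" packed call is still a placeholder):

    import tvm

    n, padded, factor = 919, 1024, 128

    in_data = tvm.placeholder((n,), name="in_data")
    # the pad keeps its data-dependent select and stays un-tensorized
    pad_data = tvm.compute((padded,),
                           lambda i: tvm.select(i < n, in_data[i], tvm.const(0, dtype=in_data.dtype)),
                           name="pad_data")
    # same select pattern as intrin_abs, so the tensorize match can succeed
    res_temp = tvm.compute((padded,),
                           lambda i: tvm.select(pad_data[i] < 0, -pad_data[i], pad_data[i]),
                           name="res_temp")
    res = tvm.compute((n,), lambda i: res_temp[i], name="res")

    sch = tvm.create_schedule(res.op)
    # 1024 = 8 * 128: the split is even, so no likely(...) guard appears
    xo, xi = sch[res_temp].split(sch[res_temp].op.axis[0], factor=factor)
    sch[res_temp].tensorize(xi, intrin_abs(factor, in_data.dtype))
    print(tvm.lower(sch, [in_data, res], simple_mode=True))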