@eqy Thanks for the reply.
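I tried the following:
- input shape = (919,)  # 919 is a prime number
- pad the input to shape (1024,)
- compute res_temp (the absolute value of the padded data)
- compute res (the first 919 elements of res_temp)
The code is shown below: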
# ---- compute definition ----
# 1-D input of 919 elements. 919 = 7 * 128 + 23, so it is not a multiple of
# the split factor (128) used in the schedule below.
in_data = tvm.placeholder((919,), name="in_data")
# Pad to 1024 (= 8 * 128) elements: positions i >= 919 are filled with 0.
pad_data = tvm.compute((1024,),
lambda i: tvm.select(i < 919, in_data[i], tvm.const(0, dtype=in_data.dtype)),
name="pad_data")
# Element-wise absolute value of the padded tensor.
res_temp = tvm.compute((1024,),
lambda i: tvm.select(pad_data[i] >= 0, pad_data[i], -pad_data[i]),
name="res_temp")
# Final result: only the first 919 elements of res_temp are kept.
res = tvm.compute((919,),
lambda i: res_temp[i],
name="res")
# ---- schedule ----
sch = tvm.create_schedule(res.op)
# Debug: dump the default (unscheduled) lowering.
# print(tvm.lower(sch, [in_data, res], simple_mode=True))
# Add cache stages: one for reading in_data (consumed by pad_data) and one
# each for writing pad_data and res_temp.
# NOTE(review): `param` is not defined in this snippet; the lowered IR printed
# after this script reports storage_scope = "local.UB", so param.scope_ubuf is
# presumably that scope -- confirm.
in_data_buf = sch.cache_read(in_data, param.scope_ubuf, [pad_data])
pad_data_buf = sch.cache_write(pad_data, param.scope_ubuf)
res_temp_buf = sch.cache_write(res_temp, param.scope_ubuf)
# Inline the original pad_data / res_temp stages; the lowered IR then shows
# only their cached counterparts (pad_data.local.UB, res_temp.local.UB).
sch[pad_data].compute_inline()
sch[res_temp].compute_inline()
# Debug: dump the lowering after caching and inlining.
# print(tvm.lower(sch, [in_data, res], simple_mode=True))
# Split the single axis of every stage into (outer, inner) with an inner
# extent of 128.
factor=128
in_data_buf_xo, in_data_buf_xi = sch[in_data_buf].split(sch[in_data_buf].op.axis[0], factor=factor)
pad_data_buf_xo, pad_data_buf_xi = sch[pad_data_buf].split(sch[pad_data_buf].op.axis[0], factor=factor)
res_temp_buf_xo, res_temp_buf_xi = sch[res_temp_buf].split(sch[res_temp_buf].op.axis[0], factor=factor)
res_xo, res_xi = sch[res].split(sch[res].op.axis[0], factor=factor)
# Attach all three cache stages under the outer loop of res, so each
# iteration of res_xo handles one 128-element tile end to end.
sch[res_temp_buf].compute_at(sch[res], res_xo)
sch[pad_data_buf].compute_at(sch[res], res_xo)
sch[in_data_buf].compute_at(sch[res], res_xo)
# Print the lowered IR of the final schedule.
print(tvm.lower(sch, [in_data, res], simple_mode=True))
The output is:
// attr [in_data.local.UB] storage_scope = "local.UB"
allocate in_data.local.UB[float32 * 128]
custom_new { 0 }
custom_delete { nop(<args>); }
produce res {
for (i.outer, 0, 8) {
produce in_data.local.UB {
for (ax0.inner, 0, 128) {
if (likely(((i.outer*128) < (919 - ax0.inner)))) {
in_data.local.UB[ax0.inner] = in_data[((i.outer*128) + ax0.inner)]
}
}
}
produce pad_data.local.UB {
for (i.c.inner, 0, 128) {
in_data.local.UB[i.c.inner] = tvm_if_then_else(((i.outer*128) < (919 - i.c.inner)), in_data.local.UB[i.c.inner], 0.000000f)
}
}
produce res_temp.local.UB {
for (i.c.inner, 0, 128) {
in_data.local.UB[i.c.inner] = tvm_if_then_else((in_data.local.UB[i.c.inner] < 0.000000f), (in_data.local.UB[i.c.inner]*-1.000000f), in_data.local.UB[i.c.inner])
}
}
for (i.inner, 0, 128) {
if (likely(((i.outer*128) < (919 - i.inner)))) {
res[((i.outer*128) + i.inner)] = in_data.local.UB[i.inner]
}
}
}
}
Then I tried to tensorize the pad_data.local.UB stage:
# tensorize
def intrin_pad(n, iterval):
    """Declare the tensor intrinsic used for the zero-padding stage.

    The declared computation reads a 1-D tensor of ``n`` elements and keeps
    element ``i`` while ``i < 919 - iterval * 128``; every other position is
    set to 0. The intrinsic body is a packed call to ``"vpad"`` on the input
    and output buffers.

    Args:
        n: Length of the 1-D tile the intrinsic covers.
        iterval: Value multiplied by 128 and subtracted from 919 to form the
            padding bound (the call site in this post passes an outer loop
            IterVar).

    Returns:
        The object returned by ``tvm.decl_tensor_intrin``.
    """
    tile_shape = (n,)
    src = tvm.placeholder(tile_shape, name="data")
    zero = tvm.const(0, dtype=src.dtype)
    # 919 and 128 are hard-coded; they match the input length and the split
    # factor used by the schedule in this post.
    bound = 919 - (iterval*128)
    dst = tvm.compute(
        tile_shape,
        lambda i: tvm.select(i < bound, src[i], zero),
        name="res",
    )

    def _emit_vpad(ins, outs):
        # One input buffer and one output buffer are forwarded to "vpad".
        return tvm.call_packed("vpad", ins[0], outs[0])

    with tvm.build_config(offset_factor=128):
        return tvm.decl_tensor_intrin(dst.op, _emit_vpad)
# Replace the inner (128-wide) loop of the cached pad stage with the "vpad"
# intrinsic. The intrinsic is declared with pad_data_buf_xo as its outer index.
# NOTE(review): the error reported for this script shows the declaration using
# i.c.outer while the compute provided by the schedule uses i.outer, which is
# the mismatch that makes tensorize fail.
sch[pad_data_buf].tensorize(pad_data_buf_xi, intrin_pad(128, pad_data_buf_xo))
# The traceback in this post is raised from this tvm.lower call.
print(tvm.lower(sch, [in_data, res], simple_mode=True))
Error Message:
Traceback (most recent call last):
File "test.py", line 67, in <module>
print(tvm.lower(sch, [in_data, res], simple_mode=True))
File "/home/w00421387/repote/tensor_engine/python/te/tvm/build_module.py", line 341, in lower
stmt = schedule.ScheduleOps(sch, bounds)
File "/home/w00421387/repote/tensor_engine/python/te/tvm/_ffi/function.py", line 280, in my_api_func
return flocal(*args)
File "/home/w00421387/repote/tensor_engine/python/te/tvm/_ffi/_ctypes/function.py", line 183, in __call__
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
File "/home/w00421387/repote/tensor_engine/python/te/tvm/_ffi/base.py", line 66, in check_call
raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [11:21:05] /home/w00421387/repote/tensor_engine/src/op/tensorize.cc:357: Check failed: Equal(lhs, rhs) Failed to match the compute with TensorIntrin tensor_intrin's declaration provided= select((i < (919 - (i.outer*128))), data(i), 0.000000f), intrin= select((i < (919 - (i.c.outer*128))), data(i), 0.000000f)