Hi, I am trying to define more complex TE compute functions, such as a reduction with an initializer tensor, or argmin. I ran into the issue that `te.create_prim_func` cannot handle variables that recur between `te.comm_reducer`'s `fcombine`/`fidentity` functions and the rest of the expression, such as the reduction axes or the compute lambda variables. However, I found that `tvm.lower` handles these cases fine with a `te.create_schedule`, although the resulting function has more allocations.
My questions are:

- Is this a bug in `te.create_prim_func`?
- Is there a workaround that does not use the “legacy schedule”? (The `from_legacy_te_schedule` attr is attached to the function if it is created with `tvm.lower`.)
- Will `te.create_prim_func` support these use cases, or should `tvm.lower` be used going forward, even with Relax models?
E.g., a sum reduction with an initializer tensor (bias), for simplicity.

Python script:
```python
import tvm
from tvm import te

data = te.placeholder((5, 5), "float32", "a")
init = te.placeholder((5,), "float32", "b")
ax = te.reduce_axis((0, 5), "ax")

sum_red = te.compute(
    (5,),
    lambda i: te.comm_reducer(
        lambda x, y: x + y,
        lambda t: init[i],  # `i` will be duplicated here, not recognized as the same as the lambda parameter
    )(data[i, ax], axis=[ax]),
    name="sum_red",
)

sch = te.create_schedule(sum_red.op)
tvm.lower(sch, [data, init, sum_red])["main"].show()
te.create_prim_func([data, init, sum_red]).show()
```
The resulting functions are as follows:
```python
# From tvm.lower
# from tvm.script import tir as T
@T.prim_func
def main(a: T.Buffer((5, 5), "float32"), b: T.Buffer((5,), "float32"), sum_red: T.Buffer((5,), "float32")):
    T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
    for i in range(5):
        sum_red[i] = b[i]
        for ax in range(5):
            a_1 = T.Buffer((25,), data=a.data)
            sum_red[i] = sum_red[i] + a_1[i * 5 + ax]
```
```python
# From te.create_prim_func
# from tvm.script import tir as T
@T.prim_func
def main(a: T.Buffer((5, 5), "float32"), b: T.Buffer((5,), "float32"), sum_red: T.Buffer((5,), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, ax in T.grid(5, 5):
        with T.block("sum_red"):
            v_i, v_ax = T.axis.remap("SR", [i, ax])
            T.reads(a[v_i, v_ax])
            T.writes(sum_red[v_i])
            with T.init():
                i_1 = T.int32()  # redefinition
                sum_red[v_i] = b[i_1]
            sum_red[v_i] = sum_red[v_i] + a[v_i, v_ax]
```
As seen in the second output, on the line marked `# redefinition`, the initializer's `i` variable is redefined as `i_1`, which then throws an error during build, as `i_1` is never actually defined.
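
One possible rephrasing would be to keep `fidentity` from touching the compute lambda variable at all: do a plain `te.sum` with its built-in zero identity and add the initializer in a separate elementwise stage. Below is a minimal sketch of that idea (the intermediate stage name `partial` is just for illustration); it should sidestep the duplication, but it is not a general answer, since it does not cover reducers such as argmin whose identity genuinely depends on the surrounding indices:

```python
import tvm
from tvm import te

data = te.placeholder((5, 5), "float32", "a")
init = te.placeholder((5,), "float32", "b")
ax = te.reduce_axis((0, 5), "ax")

# Stage 1: plain sum reduction with the built-in zero identity,
# so nothing inside the reducer refers to the compute lambda variable `i`.
partial = te.compute((5,), lambda i: te.sum(data[i, ax], axis=ax), name="partial")

# Stage 2: add the initializer tensor elementwise
# (intermediate stage name `partial` chosen only for this sketch).
sum_red = te.compute((5,), lambda i: partial[i] + init[i], name="sum_red")

te.create_prim_func([data, init, sum_red]).show()
```

The extra stage is exactly the kind of additional allocation I would like to avoid, which is why I am asking whether `te.create_prim_func` is supposed to support the direct formulation.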