Hi, I am trying to define more complex TE compute functions, such as reduction with an initializer tensor, or argmin.

I ran into the issue that `te.create_prim_func` cannot handle the recurrence of variables between `te.comm_reducer`'s `fcombine` and `fidentity` functions and the surrounding expression, such as the reduction axes or the compute lambda's variables.

However, I found that `tvm.lower` can handle these cases fine with a `te.schedule`, although the resulting function has more allocations.

#### My questions are:

- Is this a bug in `te.create_prim_func`?
- Is there a workaround that does not use the "legacy schedule"? (The attr is attached to the function if created with `tvm.lower`.)
- Will `te.create_prim_func` support these use cases, or should `tvm.lower` be used going forward, even with Relax models?

##### e.g.: Sum reduction with an initializer tensor (bias) for simplicity

## Python script

```python
import tvm
from tvm import te

data = te.placeholder((5, 5), "float32", "a")
init = te.placeholder((5,), "float32", "b")
ax = te.reduce_axis((0, 5), "ax")
sum_red = te.compute(
    (5,),
    lambda i: te.comm_reducer(
        lambda x, y: x + y,
        lambda t: init[i],  # `i` will be duplicated here, not recognized as the same as the lambda parameter
    )(data[i, ax], axis=[ax]),
    name="sum_red",
)
sch = te.create_schedule(sum_red.op)
tvm.lower(sch, [data, init, sum_red])["main"].show()
te.create_prim_func([data, init, sum_red]).show()
```

The resulting functions are as follows:

```python
# From tvm.lower
# from tvm.script import tir as T
@T.prim_func
def main(a: T.Buffer((5, 5), "float32"), b: T.Buffer((5,), "float32"), sum_red: T.Buffer((5,), "float32")):
    T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
    for i in range(5):
        sum_red[i] = b[i]
        for ax in range(5):
            a_1 = T.Buffer((25,), data=a.data)
            sum_red[i] = sum_red[i] + a_1[i * 5 + ax]
```

```python
# From te.create_prim_func
# from tvm.script import tir as T
@T.prim_func
def main(a: T.Buffer((5, 5), "float32"), b: T.Buffer((5,), "float32"), sum_red: T.Buffer((5,), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, ax in T.grid(5, 5):
        with T.block("sum_red"):
            v_i, v_ax = T.axis.remap("SR", [i, ax])
            T.reads(a[v_i, v_ax])
            T.writes(sum_red[v_i])
            with T.init():
                i_1 = T.int32()  # redefinition
                sum_red[v_i] = b[i_1]
            sum_red[v_i] = sum_red[v_i] + a[v_i, v_ax]
```

As seen in the `T.init()` block of the `te.create_prim_func` output, the initializer's `i` variable is redefined as `i_1`, which then throws an error during build because `i_1` is never actually defined.