te.create_prim_func vs tvm.lower

Hi, I am trying to define more complex TE compute functions, such as reduction with an initializer tensor, or argmin.

I ran into an issue: te.create_prim_func cannot handle variables that recur inside te.comm_reducer's fcombine and fidentity functions as well as in the reduction expression itself, such as the reduction axes or the compute lambda's parameters.

However, I found that tvm.lower handles these cases fine when given a te.create_schedule schedule, although the resulting function has more allocations.

My questions are:

  • Is this a bug with te.create_prim_func?
  • Is there a workaround that avoids the "legacy schedule"? (The from_legacy_te_schedule attr is attached to the function when it is created with tvm.lower.)
  • Will te.create_prim_func support these use cases, or should tvm.lower be used going forward even with Relax models?
E.g. a sum reduction with an initializer tensor (bias), for simplicity:
import tvm
from tvm import te

data = te.placeholder((5, 5), "float32", "a")
init = te.placeholder((5,), "float32", "b")
ax = te.reduce_axis((0, 5), "ax")

sum_red = te.compute(
    (5,),
    lambda i: te.comm_reducer(
        lambda x, y: x + y,
        lambda t: init[i],	# `i` will be duplicated here, not recognized as the same as the lambda parameter
    )(data[i, ax], axis=[ax]),
    name="sum_red",
)

sch = te.create_schedule(sum_red.op)
tvm.lower(sch, [data, init, sum_red])["main"].show()

te.create_prim_func([data, init, sum_red]).show()

The resulting functions are as follows:

# From tvm.lower

# from tvm.script import tir as T

@T.prim_func
def main(a: T.Buffer((5, 5), "float32"), b: T.Buffer((5,), "float32"), sum_red: T.Buffer((5,), "float32")):
    T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
    for i in range(5):
        sum_red[i] = b[i]
        for ax in range(5):
            a_1 = T.Buffer((25,), data=a.data)
            sum_red[i] = sum_red[i] + a_1[i * 5 + ax]
# From te.create_prim_func

# from tvm.script import tir as T

@T.prim_func
def main(a: T.Buffer((5, 5), "float32"), b: T.Buffer((5,), "float32"), sum_red: T.Buffer((5,), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, ax in T.grid(5, 5):
        with T.block("sum_red"):
            v_i, v_ax = T.axis.remap("SR", [i, ax])
            T.reads(a[v_i, v_ax])
            T.writes(sum_red[v_i])
            with T.init():
                i_1 = T.int32()		# redefinition
                sum_red[v_i] = b[i_1]
            sum_red[v_i] = sum_red[v_i] + a[v_i, v_ax]

As seen in the te.create_prim_func output, the initializer's i variable is redefined as i_1, which then throws an error during build because i_1 is never defined.

Thanks for the finding! It looks like a bug in te.create_prim_func, and a quick fix has been sent in https://github.com/apache/tvm/pull/17301


Thank you for the quick work.

Adding to your PR: is it possible to extend this fix to fcombine as well? It has the same problem as the initializer.

For example for an argmin like function:

argmin = te.compute(
    (5,),
    lambda i: te.comm_reducer(
        lambda x, y: te.if_then_else(data[i, x] < y, x, ax),
        lambda t: te.const(0, dtype="int32")
    )(data[i, ax].astype("int32"), axis=[ax]),    # the cast is needed because the output dtype is inferred from the expression
    name="argmin",
)

# Resulting functions:

# From tvm.lower

@T.prim_func
def main(a: T.Buffer((5, 5), "float32"), argmin: T.Buffer((5,), "int32")):
    T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
    for i in range(5):
        argmin[i] = 0
        for ax in range(5):
            cse_var_1: T.int32 = i * 5
            a_1 = T.Buffer((25,), data=a.data)
            argmin[i] = T.if_then_else(a_1[cse_var_1 + argmin[i]] < T.Cast("float32", T.Cast("int32", a_1[cse_var_1 + ax])), argmin[i], ax)


# From te.create_prim_func

@T.prim_func
def main(a: T.Buffer((5, 5), "float32"), argmin: T.Buffer((5,), "int32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, ax in T.grid(5, 5):
        with T.block("argmin"):
            v_i, v_ax = T.axis.remap("SR", [i, ax])
            T.reads(a[v_i, v_ax])
            T.writes(argmin[v_i])
            with T.init():
                argmin[v_i] = 0
            i_1 = T.int32()
            ax_1 = T.int32()
            argmin[v_i] = T.if_then_else(a[i_1, argmin[v_i]] < T.Cast("float32", T.Cast("int32", a[v_i, v_ax])), argmin[v_i], ax_1)

Got it. I would like to take a look at more recursion issues in the creator implementation.

Applying the same substitution to value in the n_buffers == 1 branch seems to solve the issue for the cases I tried, but I am out of my depth in the inner workings of TE, so further confirmation is necessary.

    for (int i = 0; i < n_buffers; ++i) {
      const Buffer& buffer = buffers[i];
//      init_stmts.push_back(BufferStore(buffer, reduce->combiner->identity_element[i], indices));
      PrimExpr identity =
          Substitute(info->transformer(reduce->combiner->identity_element[i]), var_map);
      init_stmts.push_back(BufferStore(buffer, identity, indices));
      PrimExpr value{nullptr};
      if (n_buffers > 1) {
        temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype())));
        value = temp_vars.back();
      } else {
//        value = reduce->combiner.get()->operator()(lhs, rhs)[i];
        value = Substitute(info->transformer(reduce->combiner.get()->operator()(lhs, rhs)[i]),
                            var_map);
      }
      body_stmts.push_back(BufferStore(buffer, value, indices));
    }

Thanks! I followed this and made sure all fields from the reducer are transformed. Do you have any more observed failure cases?

I have not found any other kind of failure yet, only the duplication of lambda parameters and reduction axes in fcombine and fidentity.