[MetaSchedule] [TIR] Error message: The primitive requires all the producer(s) of the given block to be present under the target loop

Hi TVM Community,

I’m trying to use meta_schedule.relay_integration.tune_relay to tune my custom relay OP. Here’s a snippet of my Tensor Expression, just for example:

d = lambda x: data[tir.if_then_else(x<0, 0, x)]

d1 = te.compute(data.shape, lambda x: te.sum(d(x), te.reduce_axis((0, 3))), "d1")
d2 = te.compute(data.shape, lambda x: te.sum(d(x), te.reduce_axis((0, 3))), "d2")

out = te.compute(data.shape, lambda x: d1[x] + d2[x], name="out")

But, I’m running into this error message:

......
for x in range(T.int64(30)):
    for rv in range(T.int64(3)):
        # tir.Block#0
        with T.block("d1"):
        ^^^^^^^^^^^^^^^^^^^
            v_x = T.axis.spatial(T.int64(30), x)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            v_rv = T.axis.reduce(T.int64(3), rv)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            T.reads(p0[T.int64(0), T.int64(0):T.int64(20)])
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            T.writes(d1[v_x])
            ^^^^^^^^^^^^^^^^^
            with T.init():
            ^^^^^^^^^^^^^^
                d1[v_x] = T.uint16(0)
                ^^^^^^^^^^^^^^^^^^^^^
            d1[v_x] = d1[v_x] + p0[T.int64(0), T.if_then_else(v_x < T.int64(0), T.int64(0), v_x)]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for x_0 in range(T.int64(6)):
    for x_1 in range(T.int64(5)):
        for rv_0 in range(T.int64(1)):
            for x_2 in range(T.int64(1)):
                for rv_1 in range(T.int64(3)):
                    for x_3 in range(T.int64(1)):
                        # tir.Block#1
                        with T.block("d2"):
                        ^^^^^^^^^^^^^^^^^^^
                            v_x = T.axis.spatial(T.int64(30), x_0 * T.int64(5) + x_1 + x_2 + x_3)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            v_rv = T.axis.reduce(T.int64(3), rv_0 * T.int64(3) + rv_1)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.reads(p0[T.int64(0), T.int64(0):T.int64(20)])
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.writes(d2[v_x])
                            ^^^^^^^^^^^^^^^^^
                            T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"})
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            with T.init():
                            ^^^^^^^^^^^^^^
                                d2[v_x] = T.uint16(0)
                                ^^^^^^^^^^^^^^^^^^^^^
                            d2[v_x] = d2[v_x] + p0[T.int64(0), T.if_then_else(v_x < T.int64(0), T.int64(0), v_x)]
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
......
Error message: The primitive requires all the producer(s) of the given block to be present under the target loop. However, there are 1 producer(s) not satisfying the constraint. List of the producer(s):tir.Block#0tir.Block#1

Can anyone explain why I’m getting this error, and is there a way to fix it?

I also tried copying the TIR and running the TVMScript provided below, but it gave me the same error:

from tvm import meta_schedule as ms
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(var_p0: T.handle, var_out: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        p0 = T.match_buffer(var_p0, (T.int64(30)), "uint8")
        out = T.match_buffer(var_out, (T.int64(30)), "uint8")
        with T.block("root"):
            T.reads()
            T.writes()
            d1 = T.alloc_buffer((T.int64(30)), "uint8")
            d2 = T.alloc_buffer((T.int64(30)), "uint8")
            for x in range(T.int64(30)):
                for rv in range(T.int64(3)):
                    # tir.Block#0
                    with T.block("d1"):
                        v_x = T.axis.spatial(T.int64(30), x)
                        v_rv = T.axis.reduce(T.int64(3), rv)
                        T.reads(p0[T.int64(0):T.int64(30)])
                        T.writes(d1[v_x])
                        with T.init():
                            d1[v_x] = T.uint16(0)
                        d1[v_x] = d1[v_x] + p0[v_x]
            for x in range(T.int64(30)):
                for rv in range(T.int64(3)):
                    # tir.Block#1
                    with T.block("d2"):
                        v_x = T.axis.spatial(T.int64(30), x)
                        v_rv = T.axis.reduce(T.int64(3), rv)
                        T.reads(p0[T.int64(0):T.int64(30)])
                        T.writes(d2[v_x])
                        with T.init():
                            d2[v_x] = T.uint16(0)
                        d2[v_x] = d2[v_x] + p0[v_x]
            for i in range(T.int64(30)):
                with T.block("out"):
                    v_i = T.axis.spatial(T.int64(30), i)
                    T.reads(d1[v_i], d2[v_i])
                    T.writes(out[v_i])
                    out[v_i] = T.Cast("uint8", d1[v_i] + d2[v_i])                                 


db = ms.tir_integration.tune_tir(Module, "llvm -num-cores=1", "./meta_schedule", 64)
sch = ms.tir_integration.compile_tir(db, Module, "llvm -num-cores=1")

I noticed that removing the unused reduce axis allowed successful tuning. So, I guess that the scheduler might be trying to merge the two loops automatically but is facing issues due to the reduce axis.

I’m not entirely sure if my guess is correct. If it’s not, could someone provide some insights or ideas? Thanks!

Meta-schedule does not support two reductions in one function

I’m quite curious about the meaning of “one function” mentioned here.

There is a TVMScript that cannot be optimized using meta-schedule:

@I.ir_module
class Module:
    @T.prim_func
    def main(dd: T.Buffer((30,), "float32"), out: T.Buffer((30,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        d1 = T.alloc_buffer((30,))
        d2 = T.alloc_buffer((30,))
        for x, rv in T.grid(30, 3):
            with T.block("d1"):
                v_x, v_rv = T.axis.remap("SR", [x, rv])
                T.reads(dd[0:30])
                T.writes(d1[v_x])
                with T.init():
                    d1[v_x] = T.float32(0)
                d1[v_x] = d1[v_x] + dd[v_x]
        for x, rv in T.grid(30, 3):
            with T.block("d2"):
                v_x, v_rv = T.axis.remap("SR", [x, rv])
                T.reads(dd[0:30])
                T.writes(d2[v_x])
                with T.init():
                    d2[v_x] = T.float32(0)
                d2[v_x] = d2[v_x] + dd[v_x]
        for x in range(30):
            with T.block("out"):
                v_x = T.axis.spatial(30, x)
                T.reads(d1[v_x], d2[v_x])
                T.writes(out[v_x])
                out[v_x] = d1[v_x] + d2[v_x]

But if I use T.reads(dd[v_x]) instead of T.reads(dd[0:30]), it will pass.

@I.ir_module
class Module:
    @T.prim_func
    def main(dd: T.Buffer((30,), "float32"), out: T.Buffer((30,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        d1 = T.alloc_buffer((30,))
        d2 = T.alloc_buffer((30,))
        for x, rv in T.grid(30, 3):
            with T.block("d1"):
                v_x, v_rv = T.axis.remap("SR", [x, rv])
                T.reads(dd[v_x])
                T.writes(d1[v_x])
                with T.init():
                    d1[v_x] = T.float32(0)
                d1[v_x] = d1[v_x] + dd[v_x]
        for x, rv in T.grid(30, 3):
            with T.block("d2"):
                v_x, v_rv = T.axis.remap("SR", [x, rv])
                T.reads(dd[v_x])
                T.writes(d2[v_x])
                with T.init():
                    d2[v_x] = T.float32(0)
                d2[v_x] = d2[v_x] + dd[v_x]
        for x in range(30):
            with T.block("out"):
                v_x = T.axis.spatial(30, x)
                T.reads(d1[v_x], d2[v_x])
                T.writes(out[v_x])
                out[v_x] = d1[v_x] + d2[v_x]

However, this version still appears to contain the "two reductions in one function" pattern you mentioned. Could you please explain the difference between these two syntaxes?