Question about parallel primitives

Hi, I’m working through the hand-written scheduling exercise in Tianqi’s Machine Learning Compilation course. I found that the schedule fails to apply after I changed where the parallel primitive is called.

@tvm.script.ir_module
class MyBmmRelu:
  @T.prim_func
  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    Y = T.alloc_buffer([16, 128, 128], dtype="float32")
    for n, i, j, kk in T.grid(16, 128, 128, 128):
        with T.block("Y"):
            vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, kk])
            with T.init():
                Y[vn, vi, vj] = T.float32(0)
            Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
    for n, i, j in T.grid(16, 128, 128):
        with T.block("C"):
            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
            C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))
            # C[vn, vi, vj] = T.float32(0)
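
For reference, the kernel is just a batched matmul followed by a ReLU. A minimal numpy sketch of the intended result (the helper name here is my own, only for checking):

import numpy as np

def bmm_relu_ref(A, B):
    # Batched matmul over the leading (batch) axis, then ReLU.
    Y = np.einsum("nik,nkj->nij", A, B)  # shape (16, 128, 128)
    return np.maximum(Y, 0.0)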

The following version of the schedule applies without error:

sch = tvm.tir.Schedule(MyBmmRelu)
# TODO: transformations
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.

# Step 1. Get blocks
Y = sch.get_block("Y", func_name="bmm_relu")
C = sch.get_block("C", func_name="bmm_relu")

# Step 2. Get loops
n, i, j, k = sch.get_loops(Y)
j0, j1 = sch.split(j, factors=[16, 8])
k0, k1 = sch.split(k, factors=[32, 4])
sch.reorder(i, j0, j1)

Y = sch.get_block("Y", func_name="bmm_relu")
sch.reverse_compute_at(C, j0)

sch.parallel(n)
Y_init = sch.decompose_reduction(Y, j1)


sch.reorder(k0, k1, j1)

In, Ii, Ij0, Ij1 = sch.get_loops(Y_init)
sch.vectorize(Ij1)

Cn, Ci, Cj0, Cj1 = sch.get_loops(C)
sch.vectorize(Cj1)

#sch.parallel(Cn)

Yn, Yi, Yj0, Yk0, Yk1, Yj1 = sch.get_loops(Y)
sch.unroll(Yk1)

IPython.display.Code(sch.mod.script(), language="python")
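
With this version I can also build the module and check the result numerically, e.g. (targeting llvm; the reference computation below is just my own sanity check):

import numpy as np

rt_lib = tvm.build(sch.mod, target="llvm")

a_np = np.random.rand(16, 128, 128).astype("float32")
b_np = np.random.rand(16, 128, 128).astype("float32")
c_ref = np.maximum(a_np @ b_np, 0.0)  # batched matmul + ReLU reference

a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((16, 128, 128), dtype="float32")
rt_lib["bmm_relu"](a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_nd.numpy(), c_ref, rtol=1e-5)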

The following version could not (note that only the position of the parallel call has changed):

sch = tvm.tir.Schedule(MyBmmRelu)
# TODO: transformations
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.

# Step 1. Get blocks
Y = sch.get_block("Y", func_name="bmm_relu")
C = sch.get_block("C", func_name="bmm_relu")

# Step 2. Get loops
n, i, j, k = sch.get_loops(Y)
j0, j1 = sch.split(j, factors=[16, 8])
k0, k1 = sch.split(k, factors=[32, 4])
sch.reorder(i, j0, j1)

Y = sch.get_block("Y", func_name="bmm_relu")
sch.reverse_compute_at(C, j0)

Y_init = sch.decompose_reduction(Y, j1)
sch.parallel(n)

sch.reorder(k0, k1, j1)

In, Ii, Ij0, Ij1 = sch.get_loops(Y_init)
sch.vectorize(Ij1)

Cn, Ci, Cj0, Cj1 = sch.get_loops(C)
sch.vectorize(Cj1)

#sch.parallel(Cn)

Yn, Yi, Yj0, Yk0, Yk1, Yj1 = sch.get_loops(Y)
sch.unroll(Yk1)

IPython.display.Code(sch.mod.script(), language="python")

Error

	"message": "Traceback (most recent call last):
  3: TVMFuncCall
  2: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFvNS_3tir8ScheduleERKNS5_6LoopRVEEE17AssignTypedLambdaIZNS0_8Registry15set_body_methodIS6_NS5_12ScheduleNodeEvJS9_EvEERSD_MT0_FT1_DpT2_EEUlS6_S9_E_EEvT_SsEUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SP_ST_
  1: tvm::tir::TracedScheduleNode::Parallel(tvm::tir::LoopRV const&)
  0: tvm::tir::ConcreteScheduleNode::Parallel(tvm::tir::LoopRV const&) [clone .cold]
ScheduleError: An error occurred in the schedule primitive 'parallel'.
The IR with diagnostic is:
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]):
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "bmm_relu"})
        # body
        # with T.block("root")
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # tir.For#0
        for n in T.serial(16):
        ^^^^^^^^^^^^^^^^^^^^^^
            for i, j_0 in T.grid(128, 16):
                for j_1_init in T.serial(8):
                    with T.block(\"Y_init\"):
                        vn, vi = T.axis.remap(\"SS\", [n, i])
                        vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
                        T.reads()
                        T.writes(Y[vn, vi, vj])
                        Y[vn, vi, vj] = T.float32(0)
                for j_1, kk_0, kk_1 in T.grid(8, 32, 4):
                    # tir.Block#1
                    with T.block(\"Y_update\"):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
                        vn, vi = T.axis.remap(\"SS\", [n, i])
                        vj = T.axis.spatial(128, j_0 * 8 + j_1)
                        vk = T.axis.reduce(128, kk_0 * 4 + kk_1)
                        T.reads(Y[vn, vi, vj], A[vn, vi, vk], B[vn, vk, vj])
                        T.writes(Y[vn, vi, vj])
                        Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
                for ax0 in T.serial(8):
                    with T.block(\"C\"):
                        vn, vi = T.axis.remap(\"SS\", [n, i])
                        vj = T.axis.spatial(128, j_0 * 8 + ax0)
                        T.reads(Y[vn, vi, vj])
                        T.writes(C[vn, vi, vj])
                        C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))
    
Error message: The queried subtree root tir.For#0 in SRef tree does not have compact dataflow, because its child block tir.Block#1 on SRef tree is neither a local complete block nor a local reduction block.
It violates condition #1 as a local complete block.
Definition of a local complete block:
1) All block vars are data parallel
2) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
3) No overlap between the buffers the block reads and writes
It violates condition #1 as a local reduction block.
Definition of a reduction block:
1) The block has the `init` statement
2) All the block bindings are quasi-affine expressions
3) All block vars are either data parallel block vars or reduction block vars
4) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
5) The reduction block vars are not used to index the output buffers
",

Yes, it is recommended to use decompose_reduction as late as possible in your schedule, to prevent this “complete block” error.
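
In other words, the only difference is the relative order of these two calls. A minimal sketch of the rule, using the loop/block handles from the schedules above:

# Works: when parallel(n) is applied, the subtree under `n` still contains the
# original block "Y", which has an init statement and is a valid local
# reduction block, so the compact-dataflow check passes.
sch.parallel(n)
Y_init = sch.decompose_reduction(Y, j1)

# Fails: after decompose_reduction, the subtree under `n` contains "Y_update",
# which has no init statement and still carries the reduction var vk, so it is
# neither a local complete block nor a local reduction block, and parallel(n)
# raises the ScheduleError shown above.
# Y_init = sch.decompose_reduction(Y, j1)
# sch.parallel(n)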