Hi, I'm working through the hand-written scheduling exercise in Tianqi's Machine Learning Compilation course. I found that my schedule no longer works after I moved the call to the `parallel` primitive to a different position.
# TensorIR module with a single primitive function, `bmm_relu`:
#   Y[n, i, j] = sum over k of A[n, i, k] * B[n, k, j]
#   C[n, i, j] = max(Y[n, i, j], 0)
# All buffers are float32 with shape (16, 128, 128).
@tvm.script.ir_module
class MyBmmRelu:
    @T.prim_func
    def bmm_relu(
        A: T.Buffer[(16, 128, 128), "float32"],
        B: T.Buffer[(16, 128, 128), "float32"],
        C: T.Buffer[(16, 128, 128), "float32"],
    ):
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # Intermediate buffer holding the matmul result before the max().
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        # Block "Y": n, i, j are spatial axes ("S"), kk is the reduction axis ("R").
        for n, i, j, kk in T.grid(16, 128, 128, 128):
            with T.block("Y"):
                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, kk])
                with T.init():
                    Y[vn, vi, vj] = T.float32(0)
                Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
        # Block "C": element-wise max(Y, 0) over three spatial axes.
        for n, i, j in T.grid(16, 128, 128):
            with T.block("C"):
                vn, vi, vj = T.axis.remap("SSS", [n, i, j])
                C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))
                #C[vn,vi,vj]=T.float32(0)
The following schedule works:
# Schedule for MyBmmRelu.bmm_relu in which `sch.parallel(n)` is applied
# BEFORE `sch.decompose_reduction`, i.e. while block "Y" still carries its
# `T.init()` statement.
sch = tvm.tir.Schedule(MyBmmRelu)
# TODO: transformations
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.
# Step 1. Get blocks
Y = sch.get_block("Y", func_name="bmm_relu")
C = sch.get_block("C", func_name="bmm_relu")
# Step 2. Get loops
n, i, j, k = sch.get_loops(Y)
# Tile j (extent 128) as 16 x 8 and the reduction loop k (extent 128) as 32 x 4.
j0, j1 = sch.split(j, factors=[16, 8])
k0, k1 = sch.split(k, factors = [32, 4])
# NOTE(review): i, j0, j1 look like they are already in this order after the
# splits, so this reorder may be a no-op -- confirm.
sch.reorder(i, j0, j1)
# NOTE(review): Y was already fetched in Step 1; this second lookup looks
# redundant -- confirm before removing.
Y = sch.get_block("Y", func_name="bmm_relu")
# Move consumer block C under loop j0, after the computation of Y.
sch.reverse_compute_at(C, j0)
# Parallelize the outermost loop while Y is still an undecomposed reduction
# block (it still has its init statement at this point).
sch.parallel(n)
# Split Y's initialization out into its own block (returned as Y_init),
# inserted at loop j1.
Y_init = sch.decompose_reduction(Y, j1)
# Make j1 the innermost loop of the update block: ..., k0, k1, j1.
sch.reorder(k0, k1, j1)
# Vectorize the innermost loop of the init block.
In, Ii, Ij0, Ij1 = sch.get_loops(Y_init)
sch.vectorize(Ij1)
# Vectorize the innermost loop of block C.
Cn, Ci, Cj0, Cj1 = sch.get_loops(C)
sch.vectorize(Cj1)
#sch.parallel(Cn)
# Unroll the inner reduction loop k1 (extent 4) of the update block.
Yn, Yi, Yj0, Yk0, Yk1, Yj1 = sch.get_loops(Y)
sch.unroll(Yk1)
# Show the transformed module as TVMScript.
IPython.display.Code(sch.mod.script(), language="python")
The following schedule does not work (note that the position of the `sch.parallel(n)` call has changed):
# Same schedule as the working variant, except that `sch.parallel(n)` is
# applied AFTER `sch.decompose_reduction`. This ordering is the one that
# raises the ScheduleError quoted in the error output below.
sch = tvm.tir.Schedule(MyBmmRelu)
# TODO: transformations
# Hints: you can use
# `IPython.display.Code(sch.mod.script(), language="python")`
# or `print(sch.mod.script())`
# to show the current program at any time during the transformation.
# Step 1. Get blocks
Y = sch.get_block("Y", func_name="bmm_relu")
C = sch.get_block("C", func_name="bmm_relu")
# Step 2. Get loops
n, i, j, k = sch.get_loops(Y)
# Tile j (extent 128) as 16 x 8 and the reduction loop k (extent 128) as 32 x 4.
j0, j1 = sch.split(j, factors=[16, 8])
k0, k1 = sch.split(k, factors = [32, 4])
# NOTE(review): i, j0, j1 look like they are already in this order after the
# splits, so this reorder may be a no-op -- confirm.
sch.reorder(i, j0, j1)
# NOTE(review): Y was already fetched in Step 1; this second lookup looks
# redundant -- confirm before removing.
Y = sch.get_block("Y", func_name="bmm_relu")
# Move consumer block C under loop j0, after the computation of Y.
sch.reverse_compute_at(C, j0)
# Split Y's initialization out into a separate "Y_init" block; what remains
# of Y is the "Y_update" block shown in the diagnostic, which has no init
# statement but still has the reduction axis vk.
Y_init = sch.decompose_reduction(Y, j1)
# This is the call that fails. Per the posted error message, the subtree under
# loop n does not have compact dataflow: Y_update is neither a local complete
# block (not all of its block vars are data parallel) nor a reduction block
# (it no longer has an `init` statement).
sch.parallel(n)
# NOTE(review): the statements below are presumably never reached, because
# the call above raises.
sch.reorder(k0, k1, j1)
# Vectorize the innermost loop of the init block.
In, Ii, Ij0, Ij1 = sch.get_loops(Y_init)
sch.vectorize(Ij1)
# Vectorize the innermost loop of block C.
Cn, Ci, Cj0, Cj1 = sch.get_loops(C)
sch.vectorize(Cj1)
#sch.parallel(Cn)
# Unroll the inner reduction loop k1 (extent 4) of the update block.
Yn, Yi, Yj0, Yk0, Yk1, Yj1 = sch.get_loops(Y)
sch.unroll(Yk1)
# Show the transformed module as TVMScript.
IPython.display.Code(sch.mod.script(), language="python")
The error message is:
"message": "Traceback (most recent call last):
3: TVMFuncCall
2: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFvNS_3tir8ScheduleERKNS5_6LoopRVEEE17AssignTypedLambdaIZNS0_8Registry15set_body_methodIS6_NS5_12ScheduleNodeEvJS9_EvEERSD_MT0_FT1_DpT2_EEUlS6_S9_E_EEvT_SsEUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SP_ST_
1: tvm::tir::TracedScheduleNode::Parallel(tvm::tir::LoopRV const&)
0: tvm::tir::ConcreteScheduleNode::Parallel(tvm::tir::LoopRV const&) [clone .cold]
ScheduleError: An error occurred in the schedule primitive 'parallel'.
The IR with diagnostic is:
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
@T.prim_func
def bmm_relu(A: T.Buffer[(16, 128, 128), \"float32\"], B: T.Buffer[(16, 128, 128), \"float32\"], C: T.Buffer[(16, 128, 128), \"float32\"]):
# function attr dict
T.func_attr({\"tir.noalias\": True, \"global_symbol\": \"bmm_relu\"})
# body
# with T.block(\"root\")
Y = T.alloc_buffer([16, 128, 128], dtype=\"float32\")
# tir.For#0
for n in T.serial(16):
^^^^^^^^^^^^^^^^^^^^^^
for i, j_0 in T.grid(128, 16):
for j_1_init in T.serial(8):
with T.block(\"Y_init\"):
vn, vi = T.axis.remap(\"SS\", [n, i])
vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
T.reads()
T.writes(Y[vn, vi, vj])
Y[vn, vi, vj] = T.float32(0)
for j_1, kk_0, kk_1 in T.grid(8, 32, 4):
# tir.Block#1
with T.block(\"Y_update\"):
^^^^^^^^^^^^^^^^^^^^^^^^^
vn, vi = T.axis.remap(\"SS\", [n, i])
vj = T.axis.spatial(128, j_0 * 8 + j_1)
vk = T.axis.reduce(128, kk_0 * 4 + kk_1)
T.reads(Y[vn, vi, vj], A[vn, vi, vk], B[vn, vk, vj])
T.writes(Y[vn, vi, vj])
Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
for ax0 in T.serial(8):
with T.block(\"C\"):
vn, vi = T.axis.remap(\"SS\", [n, i])
vj = T.axis.spatial(128, j_0 * 8 + ax0)
T.reads(Y[vn, vi, vj])
T.writes(C[vn, vi, vj])
C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))
Error message: The queried subtree root tir.For#0 in SRef tree does not have compact dataflow, because its child block tir.Block#1 on SRef tree is neither a local complete block nor a local reduction block.
It violates condition #1 as a local complete block.
Definition of a local complete block:
1) All block vars are data parallel
2) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
3) No overlap between the buffers the block reads and writes
It violates condition #1 as a local reduction block.
Definition of a reduction block:
1) The block has the `init` statement
2) All the block bindings are quasi-affine expressions
3) All block vars are either data parallel block vars or reduction block vars
4) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
5) The reduction block vars are not used to index the output buffers
",