Recommended practice on scheduling TVMScript programs.

I have recently been exploring TVMScript to implement “long tail” kernels for NVIDIA GPUs, but I have run into some issues while attempting to tune them with MetaSchedule. To be more specific, below are four cases that represent common logic in kernel development where errors occur:

1. for loop starting from a non-zero value

@T.prim_func
def main(
    A: T.Buffer((100,), "float16"),
    B: T.Buffer((100,), "float16"),
):
    for i_ in range(1, 100):
        with T.block("block"):
            i = T.axis.remap("S", [i_,])
            B[i] = A[i] + 1
Error message: The loop tir.For#0 does not start with 0, which is not supported
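
For case 1, one rewrite I can think of is to iterate from 0 and move the offset into the block-var binding (a sketch only, assuming the usual `from tvm.script import tir as T` import as in the snippets above; I have not verified that MetaSchedule tunes it end-to-end):

@T.prim_func
def main(
    A: T.Buffer((100,), "float16"),
    B: T.Buffer((100,), "float16"),
):
    # the loop itself starts at 0; the +1 offset lives in the block-var binding
    for i_ in range(99):
        with T.block("block"):
            i = T.axis.spatial(100, i_ + 1)
            B[i] = A[i] + T.float16(1)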

2. prefix sum

@T.prim_func
def main(
    A: T.Buffer((100,), "float16"),
    B: T.Buffer((100,), "float16"),
):
    B[0] = 0
    for i_ in range(99):
        with T.block("block"):
            i = T.axis.remap("R", [i_,])
            B[i+1] = B[i] + A[i+1]
Error message: The queried subtree root tir.For#0 in SRef tree does not have compact dataflow, because its child block tir.Block#1 on SRef tree is neither a local complete block nor a local reduction block.
It violates condition #1 as a local complete block.
Definition of a local complete block:
1) All block vars are data parallel
2) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
3) No overlap between the buffers the block reads and writes
It violates condition #1 as a local reduction block.
Definition of a reduction block:
1) The block has the `init` statement
2) All the block bindings are quasi-affine expressions
3) All block vars are either data parallel block vars or reduction block vars
4) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
5) The reduction block vars are not used to index the output buffers
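
For case 2, my understanding is that a scan is inherently serial along the scanned axis, so by definition it cannot be a complete block or a reduction block. The closest alternative I can think of is to leave the loop variable unbound so the block stays opaque to scheduling (a rough sketch; I believe it builds, but it leaves nothing for the tuner to exploit):

@T.prim_func
def main(
    A: T.Buffer((100,), "float16"),
    B: T.Buffer((100,), "float16"),
):
    with T.block("init"):
        B[0] = T.float16(0)
    for i in range(99):
        with T.block("scan"):
            # no block var is bound; the block reads the loop var directly,
            # so the loop-carried dependence keeps the loop serial
            B[i + 1] = B[i] + A[i + 1]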

3. in-place operation

@T.prim_func
def main(
    A: T.Buffer((100,), "float16"),
    B: T.Buffer((100,), "float16"),
):
    for i_ in range(100):
        with T.block("block"):
            i = T.axis.remap("S", [i_,])
            B[i] = B[i] + A[i]
Error message: The queried subtree root tir.For#0 in SRef tree does not have compact dataflow, because its child block tir.Block#1 on SRef tree is neither a local complete block nor a local reduction block.
It violates condition #3 as a local complete block.
Definition of a local complete block:
1) All block vars are data parallel
2) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
3) No overlap between the buffers the block reads and writes
It violates condition #1 as a local reduction block.
Definition of a reduction block:
1) The block has the `init` statement
2) All the block bindings are quasi-affine expressions
3) All block vars are either data parallel block vars or reduction block vars
4) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
5) The reduction block vars are not used to index the output buffers
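
For case 3, the obvious rewrite that satisfies the “no read/write overlap” condition is to make the update out-of-place with a separate output buffer (a sketch; it costs an extra buffer, and I am not sure it is the recommended practice):

@T.prim_func
def main(
    A: T.Buffer((100,), "float16"),
    B: T.Buffer((100,), "float16"),
    C: T.Buffer((100,), "float16"),
):
    for i_ in range(100):
        with T.block("block"):
            i = T.axis.remap("S", [i_])
            # C is a separate output, so the read set (A, B) and the write set (C) no longer overlap
            C[i] = B[i] + A[i]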

4. nested loops mixed with an if condition

@T.prim_func
def main(A: T.Buffer((100,100), "float16")):
    for i in range(100):
        if i % 2 == 0:
            for j in range(100):
                with T.block("block1"):
                    A[i, j] = 1
        else:
            for j in range(100):
                with T.block("block2"):
                    A[i, j] = 2
Error message: The loops can't be fused because the inner loop tir.For#1 is not the only child of outer loop tir.For#0.
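
For case 4, folding the branch into the block body with T.if_then_else yields a single loop nest and a single block that the scheduler can work with (a sketch; I have not measured whether the per-element select hurts performance):

@T.prim_func
def main(A: T.Buffer((100, 100), "float16")):
    for i, j in T.grid(100, 100):
        with T.block("block"):
            vi, vj = T.axis.remap("SS", [i, j])
            # the if/else becomes a select inside one block
            A[vi, vj] = T.if_then_else(vi % 2 == 0, T.float16(1), T.float16(2))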

The cases above are not comprehensive; they are just illustrative examples. The error messages clearly state the constraints of the tuning system, so I assume these are limitations rather than bugs. I also tried the tuning-free DLight, which appears to support an even narrower set of programs than MetaSchedule.

Given this, I would like to ask a few questions. When implementing customized “long tail” kernels (such as NMS, as opposed to common workloads like GEMM or injective operations), how should we tune the kernels to achieve optimal performance? Are manually implementing optimized, target-specific versions or manually scheduling them with schedule primitives the only viable options? What are the best practices for tuning a TVMScript program?

I would greatly appreciate any guidance or suggestions from the community on these issues; it would help me and others better understand and optimize the use of TVMScript. Thanks a lot in advance for your help and support.


I would suggest writing the optimized kernel directly for such cases, i.e. write TVMScript that is already optimized and can be built without scheduling/tuning.
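
For example, a minimal hand-scheduled kernel with explicit thread bindings (just a sketch to illustrate the idea, not taken from any existing codebase) could look like:

@T.prim_func
def add_one(
    A: T.Buffer((100,), "float16"),
    B: T.Buffer((100,), "float16"),
):
    # thread bindings are written by hand, so no schedule/tuning pass is needed
    for bx in T.thread_binding(1, thread="blockIdx.x"):
        for tx in T.thread_binding(128, thread="threadIdx.x"):
            with T.block("block"):
                i = T.axis.spatial(100, bx * 128 + tx)
                T.where(bx * 128 + tx < 100)
                B[i] = A[i] + T.float16(1)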