Schedule composite OPs built with TOPI

I am new to TVM/Tensor Expression and am doing an exercise of building a new OP by connecting existing TOPI OPs, such as a pooling followed by a convolution, targeting CUDA. I can construct the compute from existing TOPI OPs, e.g. a topi.nn.conv2d followed by a topi.nn.pool. After learning from the tutorials, docs, code, and tests, I was able to schedule them individually with generic.schedule_conv2d_nchw and generic.schedule_pool. But I am not sure what the proper approach is to build a schedule for the combined OP. One thought is to construct a new schedule from the two existing generic schedules, but I don't know how to combine two schedules for one compute. Another approach is to construct a new schedule from scratch, though I would still like to reuse parts of the existing schedules, such as the tiling or thread binding. The examples in https://docs.tvm.ai/tutorials/topi/intro_topi.html do not show how to schedule TOPI OPs that cannot be fused. My first question is where to find such examples.
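
For concreteness, here is a minimal sketch of what I can already do (the shapes are arbitrary examples, and the names s_conv/s_pool are just mine):

import tvm
import topi

INPUT = tvm.placeholder((1, 8, 32, 32), name='INPUT')
KERNEL = tvm.placeholder((8, 8, 3, 3), name='KERNEL')

with tvm.target.create("cuda"):
    # compute: pooling followed by a convolution
    pool = topi.nn.pool(INPUT, (2, 2), (2, 2), (0, 0, 0, 0), "avg")
    conv = topi.nn.conv2d(pool, KERNEL, strides=(1, 1), padding=(1, 1), dilation=(1, 1))

    # scheduling each OP individually works ...
    s_conv = topi.generic.schedule_conv2d_nchw([conv])
    s_pool = topi.generic.schedule_pool([pool], "NCHW")
    # ... but these are two separate Schedule objects, and I do not see how to
    # combine them into a single schedule for the composite OP.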

Further, I wonder whether the TOPI generic schedules are meant to be reused as building blocks at a level above the scheduling primitives. I notice that generic.schedule_conv2d_nchw has its own heuristics to fuse OPs. If I can somehow use it to fuse part of my composite OP, I still need to understand how much it does for me and what is left for me to do; it looks almost too powerful to reason about. If the generic schedules are not meant to be reused this way, how should scheduling logic be reused?
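
From reading topi/cuda/conv2d.py, my current understanding of its fusion heuristic is roughly the pattern below (a simplified sketch of the structure only, not the real TOPI code; the actual tiling and thread binding inside the callback is omitted):

import tvm
from topi.util import traverse_inline

def schedule_conv2d_like(outs):
    s = tvm.create_schedule([x.op for x in outs])

    def _callback(op):
        if 'conv2d_nchw' in op.tag:
            # the real schedule tiles op.output(0), binds blocks/threads, and
            # computes the surviving elementwise stage at the conv's output
            pass

    # traverse_inline walks back from outs[0].op, calls compute_inline() on
    # every injective OP that is not an output, and invokes the callback on
    # each visited OP
    traverse_inline(s, outs[0].op, _callback)
    return s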

To make my previous question more specific, I am attaching my exercise code, which takes the second approach. I first build the compute for conv(pool) with TOPI OPs, then write a schedule for it. The schedule code tries to reuse the existing generic schedules as much as possible: it calls generic.schedule_conv2d_nchw directly for the conv OP, and the code scheduling the pooling OP is mostly copied from generic.schedule_pool (the schedule_pool defined in cuda/pooling.py). This code generates working CUDA code :)

Although refactoring would greatly reduce the code duplication, it becomes much more difficult to handle a slightly more complicated design: add(conv0, conv1). I am still working on it and may need to modify/hack TVM/TOPI to get the prototype working.
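
For reference, the compute for that design is easy to build; it is the schedule that I am struggling with (shapes below are again arbitrary examples):

import tvm
import topi

with tvm.target.create("cuda"):
    DATA = tvm.placeholder((1, 8, 32, 32), name='DATA')
    W0 = tvm.placeholder((8, 8, 3, 3), name='W0')
    W1 = tvm.placeholder((8, 8, 3, 3), name='W1')
    conv0 = topi.nn.conv2d(DATA, W0, strides=(1, 1), padding=(1, 1), dilation=(1, 1))
    conv1 = topi.nn.conv2d(DATA, W1, strides=(1, 1), padding=(1, 1), dilation=(1, 1))
    out = topi.add(conv0, conv1)  # elementwise add of the two conv results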

Another lesson learned from the working example is that I have to partition the compute manually: notice that I schedule the convolution against the final output tensor but the pooling against its own output tensor. The partition may be obvious in the conv(pool) case, but not in the add(conv0, conv1) case, where only one of the convs can fuse the add. But which one?

import numpy as np
import tvm
import topi
from topi import tag
from topi import generic
from topi.util import traverse_inline
def eval_topi():
    # define compute
    input = np.random.rand(1, 8, 32, 32).astype(np.float32)
    INPUT = tvm.placeholder(input.shape, name='INPUT')
    kernel = np.random.rand(8, 8, 3, 3).astype(np.float32)
    KERNEL = tvm.placeholder(kernel.shape, name='KERNEL')
    with tvm.target.create("cuda"):
        pool = topi.nn.pool(INPUT, (2,2), (2,2),  (0,0,0,0), "avg")
        conv = topi.nn.conv2d(pool, KERNEL, strides=(1,1), padding=(1,1), dilation=(1,1))
    def create_schedule_pool_conv(outs0, outs, layout):
        # call the existing generic conv2d schedule on the final output tensor (the conv)
        s = topi.generic.schedule_conv2d_nchw(outs0)

        # pooling schedule, mostly copied from schedule_pool in topi/cuda/pooling.py
        def _schedule(PaddedInput, Pool):
            if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
                s[PaddedInput].compute_inline()
            num_thread = tvm.target.current_target(allow_none=False).max_num_threads
            if Pool.op in s.outputs:
                Out = Pool
                OL = s.cache_write(Pool, "local")
            else:
                Out = outs[0].op.output(0)
                s[Pool].set_scope("local")
            fused = s[Out].fuse(*s[Out].op.axis)
            bx, tx = s[Out].split(fused, factor=num_thread)
            s[Out].bind(bx, tvm.thread_axis("blockIdx.x"))
            s[Out].bind(tx, tvm.thread_axis("threadIdx.x"))
            if Pool.op in s.outputs:
                s[OL].compute_at(s[Out], tx)
            else:
                s[Pool].compute_at(s[Out], tx)

        scheduled_ops = []

        def traverse(OP):
            """Internal traverse function"""
            # inline all one-to-one-mapping operators except the last stage (output)
            if tag.is_broadcast(OP.tag):
                # if OP not in s.outputs:
                #     s[OP].compute_inline()
                for tensor in OP.input_tensors:
                    if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                        traverse(tensor.op)
            # schedule pool
            elif OP.tag.startswith('pool'):
                PaddedInput = OP.input_tensors[0]
                Pool = OP.output(0)
                _schedule(PaddedInput, Pool)
            else:
                raise RuntimeError("Unsupported operator: %s" % OP.tag)
            scheduled_ops.append(OP)

        traverse(outs[0].op)
        return s
    # schedule
    with tvm.target.cuda():
        s_topi = create_schedule_pool_conv([conv], [pool], "NCHW")
        print(tvm.lower(s_topi, [INPUT, KERNEL], simple_mode=True))
        func = tvm.build(s_topi, [INPUT, KERNEL, conv], target="cuda", name="pool_conv_topi")
        print(func.imported_modules[0].get_source())
    # execute
    ctx = tvm.gpu(0)
    input_data = tvm.nd.array(input, ctx)
    kernel_data = tvm.nd.array(kernel, ctx)
    output_data = tvm.nd.array(np.zeros((1, 8, 16, 16), dtype=input_data.dtype), ctx)
    func(input_data, kernel_data, output_data)
    evaluator = func.time_evaluator(func.entry_name, ctx, number=1)
    print('Pool_Conv: %f ms' % (evaluator(input_data, kernel_data, output_data).mean * 1e3))
if __name__ == "__main__":
    eval_topi()