To make my previous question more specific, I am attaching my exercise code that takes the second approach. I first build the compute for conv(pool) with TOPI, then write a schedule for it. The schedule code tries to reuse existing generic schedules as much as possible: it calls generic.schedule_conv2d_nchw directly for the conv TOPI, and the code for the pooling schedule is mostly copied from generic.schedule_pool (schedule_pool defined in cuda/pooling.py). This code generates working CUDA code.
Although code refactoring would greatly reduce the code duplication, it becomes much more difficult to handle a slightly more complicated design: add(conv0, conv1). I am still working on it and may need to modify/hack TVM/TOPI to get the prototype working (a rough compute sketch is at the end of this post).
Another lesson learned from the working example is that I have to do the partitioning manually. Notice that I schedule the convolution against the final output tensor and the pooling against its own output tensor. That choice may be obvious in the conv(pool) case, but not in the add(conv0, conv1) case, where only one of the convs can be fused with the add. But which one?
import numpy as np
import tvm
import topi
from topi import tag
from topi import generic
from topi.util import traverse_inline


def eval_topi():
    # define compute: conv2d(avg_pool2d(INPUT), KERNEL)
    input = np.random.rand(1, 8, 32, 32).astype(np.float32)
    INPUT = tvm.placeholder(input.shape, name='INPUT')
    kernel = np.random.rand(8, 8, 3, 3).astype(np.float32)
    KERNEL = tvm.placeholder(kernel.shape, name='KERNEL')
    with tvm.target.create("cuda"):
        pool = topi.nn.pool(INPUT, (2, 2), (2, 2), (0, 0, 0, 0), "avg")
        conv = topi.nn.conv2d(pool, KERNEL, strides=(1, 1), padding=(1, 1), dilation=(1, 1))

    def create_schedule_pool_conv(outs0, outs, layout):
        # reuse the existing generic conv2d schedule for the conv stage
        s = topi.generic.schedule_conv2d_nchw(outs0)

        # the pool scheduling below is mostly copied from schedule_pool in cuda/pooling.py
        def _schedule(PaddedInput, Pool):
            if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
                s[PaddedInput].compute_inline()
            num_thread = tvm.target.current_target(allow_none=False).max_num_threads
            if Pool.op in s.outputs:
                Out = Pool
                OL = s.cache_write(Pool, "local")
            else:
                Out = outs[0].op.output(0)
                s[Pool].set_scope("local")
            fused = s[Out].fuse(*s[Out].op.axis)
            bx, tx = s[Out].split(fused, factor=num_thread)
            s[Out].bind(bx, tvm.thread_axis("blockIdx.x"))
            s[Out].bind(tx, tvm.thread_axis("threadIdx.x"))
            if Pool.op in s.outputs:
                s[OL].compute_at(s[Out], tx)
            else:
                s[Pool].compute_at(s[Out], tx)

        scheduled_ops = []

        def traverse(OP):
            """Internal traverse function"""
            # inline all one-to-one-mapping operators except the last stage (output)
            if tag.is_broadcast(OP.tag):
                # if OP not in s.outputs:
                #     s[OP].compute_inline()
                for tensor in OP.input_tensors:
                    if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                        traverse(tensor.op)
            # schedule pool
            elif OP.tag.startswith('pool'):
                PaddedInput = OP.input_tensors[0]
                Pool = OP.output(0)
                _schedule(PaddedInput, Pool)
            else:
                raise RuntimeError("Unsupported operator: %s" % OP.tag)
            scheduled_ops.append(OP)

        traverse(outs[0].op)
        return s

    # schedule: conv is scheduled from the final output tensor, pool from its own output
    with tvm.target.cuda():
        s_topi = create_schedule_pool_conv([conv], [pool], "NCHW")
        print(tvm.lower(s_topi, [INPUT, KERNEL], simple_mode=True))
        func = tvm.build(s_topi, [INPUT, KERNEL, conv], target="cuda", name="pool_conv_topi")
        print(func.imported_modules[0].get_source())

    # execute and time on the GPU
    ctx = tvm.gpu(0)
    input_data = tvm.nd.array(input, ctx)
    kernel_data = tvm.nd.array(kernel, ctx)
    output_data = tvm.nd.array(np.zeros((1, 8, 16, 16), dtype=input_data.dtype), ctx)
    func(input_data, kernel_data, output_data)
    evaluator = func.time_evaluator(func.entry_name, ctx, number=1)
    print('Pool_Conv: %f ms' % (evaluator(input_data, kernel_data, output_data).mean * 1e3))


if __name__ == "__main__":
    eval_topi()
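For reference, here is a rough, untested sketch of the add(conv0, conv1) compute I am experimenting with; the tensor names and shapes below are placeholders and are not taken from the working example above, only the compute structure matters.

import tvm
import topi

# Placeholder shapes, just to illustrate the add(conv0, conv1) structure.
DATA = tvm.placeholder((1, 8, 32, 32), name='DATA')
KERNEL0 = tvm.placeholder((8, 8, 3, 3), name='KERNEL0')
KERNEL1 = tvm.placeholder((8, 8, 3, 3), name='KERNEL1')

with tvm.target.create("cuda"):
    conv0 = topi.nn.conv2d(DATA, KERNEL0, strides=(1, 1), padding=(1, 1), dilation=(1, 1))
    conv1 = topi.nn.conv2d(DATA, KERNEL1, strides=(1, 1), padding=(1, 1), dilation=(1, 1))
    out = topi.add(conv0, conv1)  # elementwise add of the two conv results

# As discussed above, only one of the convs ends up fused with the add when
# reusing the generic conv2d schedule; the other conv stays a separate stage.
# Which one to pick is the manual-partition question I have not solved yet.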