How should I schedule an op that contains a multi-out stage

I have an op like this:

y = f(x)
z1 = g1(y)
z2 = g2(y)
return h(z1, z2)

Each of f, g1, g2, and h contains a 1-to-n loop, so it would be better to fuse these four stages together. However, I can’t use compute_at to make y computed inside z1; TVM reports an error. I think it’s because y is also used by z2.

I found that this problem also occurs in softmax, so I had a look at its default CUDA schedule here. It uses 4 kernels to compute a single softmax, which is obviously not optimal.

Can I fuse these four stages using the current schedule primitives? If not, would it be possible for me or for other developers to add a new schedule primitive for this?

You can make y, z1, and z2 compute_at h.

You can look at the softmax CPU schedule as an example. All four stages end up under the same outer loop, which allows for better parallelism.
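Something along these lines should work for the op in the question. This is only a sketch: the elementwise bodies below are made-up stand-ins for f, g1, g2 and h, and only the producer/consumer structure matters.

import tvm

N = 1024
inp = tvm.placeholder((N,), name="inp")

# Stand-in bodies for f, g1, g2 and h; only the dependency shape matters here.
y = tvm.compute((N,), lambda i: inp[i] * 2, name="y")      # f
z1 = tvm.compute((N,), lambda i: y[i] + 1, name="z1")      # g1
z2 = tvm.compute((N,), lambda i: y[i] - 1, name="z2")      # g2
h = tvm.compute((N,), lambda i: z1[i] * z2[i], name="h")   # h

s = tvm.create_schedule(h.op)
# Attach every intermediate stage under h's outer loop,
# the same pattern the softmax CPU schedule uses.
s[y].compute_at(s[h], h.op.axis[0])
s[z1].compute_at(s[h], h.op.axis[0])
s[z2].compute_at(s[h], h.op.axis[0])
s[h].parallel(h.op.axis[0])  # the single shared loop can then be parallelized

print(tvm.lower(s, [inp, h], simple_mode=True))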

Let’s have a look at this simple example:

import tvm
import topi
import numpy as np

dtype = "float32"
target = "llvm"
N = 1024

inp = tvm.placeholder((N, N), dtype=dtype, name="inp")

x = tvm.compute((N, N), lambda i, j: inp[i, j] + 1)  # x is consumed by both y and z

j = tvm.reduce_axis((0, N))
y = tvm.compute((N,), lambda i: tvm.sum(tvm.sin(x[i, j]), axis=[j]))

j = tvm.reduce_axis((0, N))
z = tvm.compute((N,), lambda i: tvm.sum(tvm.cos(x[i, j]), axis=[j]))

oup = tvm.compute((N,), lambda i: y[i] * z[i])

s = tvm.create_schedule([oup.op])
compute = tvm.build(s, [inp, oup], target, name="run")  # build with the schedule as it stands here (no compute_at yet)

#s[y].compute_at(s[oup], oup.op.axis[0]) # OK if scheduling like this
s[x].compute_at(s[oup], oup.op.axis[0]) # ERROR HERE

print(tvm.lower(s, [inp, oup], simple_mode=True, name="run"))

If I make s[y] compute_at s[oup], it works fine. But if I make s[x] compute_at s[oup], TVM reports:

Traceback (most recent call last):

  File "tmp.py", line 27, in <module>
    print(tvm.lower(s, [inp, oup], simple_mode=True, name="run"))

  File "/home/rd/src/tvm_experiments/tvm-dev/python/tvm/build_module.py", line 382, in lower
    stmt = form_body(sch)

  File "/home/rd/src/tvm_experiments/tvm-dev/python/tvm/build_module.py", line 332, in form_body
    bounds = schedule.InferBound(sch)

  File "/home/rd/src/tvm_experiments/tvm-dev/python/tvm/_ffi/_ctypes/function.py", line 207, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (4) /home/rd/src/tvm_experiments/tvm-dev/build/libtvm.so(TVMFuncCall+0x48) [0x7f079bb53d68]
  [bt] (3) /home/rd/src/tvm_experiments/tvm-dev/build/libtvm.so(std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), void tvm::runtime::TypedPackedFunc<tvm::Map<tvm::IterVar, tvm::Range, void, void> (tvm::Schedule const&)>::AssignTypedLambda<tvm::Map<tvm::IterVar, tvm::Range, void, void> (*)(tvm::Schedule const&)>(tvm::Map<tvm::IterVar, tvm::Range, void, void> (*)(tvm::Schedule const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)+0x45) [0x7f079b3ac545]
  [bt] (2) /home/rd/src/tvm_experiments/tvm-dev/build/libtvm.so(tvm::schedule::InferBound(tvm::Schedule const&)+0xf1c) [0x7f079b6f443c]
  [bt] (1) /home/rd/src/tvm_experiments/tvm-dev/build/libtvm.so(tvm::schedule::InferRootBound(tvm::Stage const&, tvm::schedule::GraphContext const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > >*)+0x1ab5) [0x7f079b6f23d5]
  [bt] (0) /home/rd/src/tvm_experiments/tvm-dev/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x33) [0x7f079b366a73]
  File "/home/rd/src/tvm_experiments/tvm-dev/src/schedule/bound.cc", line 187
TVMError: Check failed: found_attach || stage_attach.size() == 0: Invalid Schedule, cannot find the producer compute(compute, 0x2427820) along the loop nest specified by compute_at of consumer compute(compute, 0x2460a40)

Why is this schedule not allowed?
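For comparison, the attachment order suggested in the replies above should pass bound inference on this example: once y and z are also attached under oup’s loop, x’s consumers live inside that loop, so x can be attached along the same loop nest. A sketch, reusing the definitions from the snippet above:

s = tvm.create_schedule([oup.op])
# Attach the consumers of x (y and z) under oup's i loop first ...
s[y].compute_at(s[oup], oup.op.axis[0])
s[z].compute_at(s[oup], oup.op.axis[0])
# ... then x can be attached along the same loop nest as well.
s[x].compute_at(s[oup], oup.op.axis[0])
print(tvm.lower(s, [inp, oup], simple_mode=True))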