How can I write an intermediate result out to a global buffer while also keeping it in a local buffer for the next computation? The computation graph is: data -> a -> b -> c, where tmp (the slice a[i, j, 0]) is also read from a and written out as a second output.
I tried code like this:
import tvm
shape = (255, 255, 255)
data = tvm.placeholder(shape, "float16", "data")
a = tvm.compute(shape, lambda i, j, k: data[i, j, k] + 1, name="a")
# tmp is the intermediate result that must also be written to global memory
tmp = tvm.compute((255, 255), lambda i, j: a[i, j, 0], name="tmp")
b = tvm.compute(shape, lambda i, j, k: a[i, j, k] + 1, name="b")
c = tvm.compute(shape, lambda i, j, k: b[i, j, k] + 1, name="c")
# both c and tmp are outputs of the schedule
s = tvm.create_schedule([c.op, tmp.op])
print(tvm.lower(s, [c, data, tmp], simple_mode=True))
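To make the goal concrete, here is a plain-Python sketch (no TVM involved) of the loop nest being asked for. For each (i, j), one row of `a` is computed into a small local buffer, `tmp[i][j]` is written out from it to the global output, and `b` and `c` are then computed from that same local row. The function names `reference` and `fused` and the small shape are hypothetical, chosen only for illustration; the sketch checks that the fused loop nest gives the same results as the unfused computation.

```python
# Plain-Python model of the intended schedule; it does not use TVM.
# The small cube size n is a hypothetical choice so the check runs quickly.

def reference(data, n):
    """Unfused version: materialize a, b and c fully, like the unscheduled lowering."""
    a = [[[data[i][j][k] + 1 for k in range(n)] for j in range(n)] for i in range(n)]
    tmp = [[a[i][j][0] for j in range(n)] for i in range(n)]
    b = [[[a[i][j][k] + 1 for k in range(n)] for j in range(n)] for i in range(n)]
    c = [[[b[i][j][k] + 1 for k in range(n)] for j in range(n)] for i in range(n)]
    return c, tmp

def fused(data, n):
    """Desired loop nest: a lives in a per-(i, j) local row, tmp goes to global."""
    c = [[[0.0] * n for _ in range(n)] for _ in range(n)]  # global output
    tmp = [[0.0] * n for _ in range(n)]                    # global output
    for i in range(n):
        for j in range(n):
            # "local" buffer: only one row of a along k is alive at a time
            a_local = [data[i][j][k] + 1 for k in range(n)]
            # write the intermediate result out to global memory ...
            tmp[i][j] = a_local[0]
            # ... and keep using the local copy for the rest of the chain
            b_local = [a_local[k] + 1 for k in range(n)]
            for k in range(n):
                c[i][j][k] = b_local[k] + 1
    return c, tmp

n = 4
data = [[[float(i * 100 + j * 10 + k) for k in range(n)] for j in range(n)] for i in range(n)]
print(fused(data, n) == reference(data, n))  # True
```

In this loop nest both consumers of `a` (`tmp` and `b`) run inside the same (i, j) loops that produce `a_local`. The error message shown further down suggests, though this is an interpretation rather than something confirmed here, that TVM's bound inference requires exactly that: every consumer of a stage attached with compute_at must lie within the loop nest it is attached to.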
This lowers to correct code, but if I try to schedule it with compute_at, the following error occurs:
Traceback (most recent call last):
File "./tmp_ddr.py", line 22, in <module>
print tvm.lower(s, [c, data, tmp], simple_mode = True)
File "/nnvm/tvm/python/tvm/build_module.py", line 330, in lower
bounds = schedule.InferBound(sch)
File "/nnvm/tvm/python/tvm/_ffi/function.py", line 280, in my_api_func
return flocal(*args)
File "/nnvm/tvm/python/tvm/_ffi/_ctypes/function.py", line 183, in __call__
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
File "/nnvm/tvm/python/tvm/_ffi/base.py", line 66, in check_call
raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [10:48:14]
/nnvm/tvm/src/schedule/bound.cc:168: Check failed: found_attach || stage_attach.size() == 0 Invalid Schedule, cannot find the producer compute(a, 0x2dc9ab0) along the loop nest specified by compute_at of consumer compute(b, 0x2dcb090)