Consider this snippet of code: import tvm from tvm import te
ELEM_BITS = 256
BUFF_SIZE = 100*1024*1024
buffer1 = "local.buffer1"
buffer2 = "local.buffer2"
buffer3 = "local.buffer3"
@tvm.register_func("tvm.info.mem.%s" % buffer1)
def mem_info_buffer1():
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=ELEM_BITS,
max_simd_bits=ELEM_BITS,
max_num_bits=BUFF_SIZE*8,
head_address=None,
)
@tvm.register_func("tvm.info.mem.%s" % buffer2)
def mem_info_buffer2():
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=ELEM_BITS,
max_simd_bits=ELEM_BITS,
max_num_bits=BUFF_SIZE*8,
head_address=None,
)
@tvm.register_func("tvm.info.mem.%s" % buffer3)
def mem_info_buffer3():
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=ELEM_BITS,
max_simd_bits=ELEM_BITS,
max_num_bits=BUFF_SIZE*8,
head_address=None,
)
input=te.placeholder((56*56*4*64, ), dtype="int16", name="input")
middle=te.compute(
(56, 56, 4, 64),
lambda i, j, m, n: input[((i*2+m//2)*112+m%2+j*2)*64+n].astype("int16"),
name="middle"
)
channel_out = 64
pooling_area = 4
height = 56
weight = 56
k = te.reduce_axis((0, pooling_area), name="k")
output = te.compute(
(height, weight, channel_out),
lambda i, j, m: (te.max(middle[i, j, k, m], axis=k)).astype("int16"),
name="output"
)
s = te.create_schedule(output.op)
cinput = s.cache_read(input, buffer1, [middle])
cmiddle = s.cache_read(middle, buffer2, [output])
coutput = s.cache_write(output, buffer3)
h, w, c = s[output].op.axis
i,j,m,n = s[middle].op.axis
k_out, k_in = s[output].split(k, factor=2)
s[output].set_scope("buffer3")
code = tvm.lower(s, [input, middle, output], simple_mode=True)
print(code)
when i run this code, i got this: > Traceback (most recent call last):
File "/home/tonywu/.config/JetBrains/PyCharm2020.3/scratches/scratch_10.py", line 71, in <module> k_out, k_in = s[output].split(k, factor=2) File "/home/tonywu/Documents/tvm/python/tvm/te/schedule.py", line 230, in split outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor) File "/home/tonywu/Documents/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__ raise get_last_ffi_error() tvm._ffi.base.TVMError: Traceback (most recent call last): [bt] (5) /home/tonywu/Documents/tvm/build/libtvm.so(TVMFuncCall+0x63) [0x7fef086dabd3] [bt] (4) /home/tonywu/Documents/tvm/build/libtvm.so(+0x9eaf57) [0x7fef07d95f57] [bt] (3) /home/tonywu/Documents/tvm/build/libtvm.so(tvm::te::Stage::split(tvm::tir::IterVar, tvm::PrimExpr, tvm::tir::IterVar*, tvm::tir::IterVar*)+0x68) [0x7fef07d8ed58] [bt] (2) /home/tonywu/Documents/tvm/build/libtvm.so(tvm::te::SplitHelper(tvm::te::StageNode*, tvm::tir::IterVar, tvm::PrimExpr, tvm::PrimExpr, tvm::tir::IterVar*, tvm::tir::IterVar*)+0x1ef) [0x7fef07d8e6cf] [bt] (1) /home/tonywu/Documents/tvm/build/libtvm.so(tvm::te::FindLeafVar(tvm::runtime::ArrayNode*, tvm::runtime::ArrayNode*, tvm::tir::IterVar const&)+0xca) [0x7fef07d8a5ba] [bt] (0) /home/tonywu/Documents/tvm/build/libtvm.so(+0x9de6c8) [0x7fef07d896c8] File "/home/tonywu/Documents/tvm/src/te/schedule/schedule_lang.cc", line 53 TVMError: Operate on iter var iter_var(k, range(min=0, ext=4))that is not part of the schedule
If I delete the cache_write or split sentence, the error will go away. I don’t know why this two sentence is in connflict. How can I avoid this bug?