scatter_nd combined with concatenate fails

scatter_nd takes 3 parameters:

  • values to put
  • indices specifying where to put the corresponding values
  • tensor shape (output shape)

If I use relay.concatenate to create the indices expression and then call scatter_nd, the evaluation fails with the error Check failed: it != buf_map_.end() == false: Cannot find buffer of buffer(T_concat, 0x7f8eb0c0d130)

Code example:

import tvm
from tvm import relay
import numpy as np

ctx = tvm.cpu(0)
target = 'llvm'
out_shape = (3,3)

hs = [0, 1, 2, 2]
ws = [0, 1, 1, 2]
vs = [2.0, 4.0, 7.0, 9.0]

# -- Works -----
vs_var = relay.var("vs", shape=(4,), dtype='float32')

ind = relay.var("ind", shape=(2,4), dtype='int64')

y = relay.scatter_nd(vs_var, ind, out_shape)

func = relay.Function([ind, vs_var], y)
intrp = relay.create_executor("graph", ctx=ctx, target=target)
intrp.evaluate(func)([hs, ws], vs)

# -- Fails ------
hs_var = relay.var("hs", shape=(1,4), dtype='int64')
ws_var = relay.var("ws", shape=(1,4), dtype='int64')
vs_var = relay.var("vs", shape=(4,), dtype='float32')

ind = relay.concatenate([hs_var, ws_var], axis=0)

y = relay.scatter_nd(vs_var, ind, out_shape)

func = relay.Function([hs_var, ws_var, vs_var], y)
intrp = relay.create_executor("graph", ctx=ctx, target=target)
intrp.evaluate(func)([hs,], [ws,], vs)

Error log:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/backend/interpreter.py", line 178, in evaluate
    return self._make_executor(expr)
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/build_module.py", line 381, in _make_executor
    mod = build(self.mod, target=self.target)
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/build_module.py", line 269, in build
    graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/build_module.py", line 132, in build
    self._build(mod, target, target_host)
  File "/Users/pivovaa/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) 9   libtvm.dylib                        0x0000000113ad0dea tvm::relay::backend::MemoizedExprTranslator<std::__1::vector<tvm::relay::backend::GraphNodeRef, std::__1::allocator<tvm::relay::backend::GraphNodeRef> > >::VisitExpr(tvm::RelayExpr const&) + 490
  [bt] (7) 8   libtvm.dylib                        0x0000000113ad3de2 tvm::relay::ExprFunctor<std::__1::vector<tvm::relay::backend::GraphNodeRef, std::__1::allocator<tvm::relay::backend::GraphNodeRef> > (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&) + 178
  [bt] (6) 7   libtvm.dylib                        0x0000000113ad40e9 tvm::NodeFunctor<std::__1::vector<tvm::relay::backend::GraphNodeRef, std::__1::allocator<tvm::relay::backend::GraphNodeRef> > (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<std::__1::vector<tvm::relay::backend::GraphNodeRef, std::__1::allocator<tvm::relay::backend::GraphNodeRef> > (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<std::__1::vector<tvm::relay::backend::GraphNodeRef, std::__1::allocator<tvm::relay::backend::GraphNodeRef> > (tvm::RelayExpr const&)>*) const + 297
  [bt] (5) 6   libtvm.dylib                        0x0000000113ad58f8 tvm::relay::ExprFunctor<std::__1::vector<tvm::relay::backend::GraphNodeRef, std::__1::allocator<tvm::relay::backend::GraphNodeRef> > (tvm::RelayExpr const&)>::InitVTable()::'lambda4'(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<std::__1::vector<tvm::relay::backend::GraphNodeRef, std::__1::allocator<tvm::relay::backend::GraphNodeRef> > (tvm::RelayExpr const&)>*)::__invoke(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<std::__1::vector<tvm::relay::backend::GraphNodeRef, std::__1::allocator<tvm::relay::backend::GraphNodeRef> > (tvm::RelayExpr const&)>*) + 24
  [bt] (4) 5   libtvm.dylib                        0x0000000113ad2a92 tvm::relay::backend::GraphRuntimeCodegen::VisitExpr_(tvm::relay::CallNode const*) + 3394
  [bt] (3) 4   libtvm.dylib                        0x0000000113ac34c9 std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<tvm::relay::CachedFunc (tvm::relay::CompileEngine, tvm::relay::CCacheKey)>::AssignTypedLambda<tvm::relay::$_8>(tvm::relay::$_8, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<tvm::relay::CachedFunc (tvm::relay::CompileEngine, tvm::relay::CCacheKey)>::AssignTypedLambda<tvm::relay::$_8>(tvm::relay::$_8, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 713
  [bt] (2) 3   libtvm.dylib                        0x0000000113ab4792 tvm::relay::CompileEngineImpl::Lower(tvm::relay::CCacheKey const&) + 18
  [bt] (1) 2   libtvm.dylib                        0x0000000113ab74ec tvm::relay::CompileEngineImpl::LowerInternal(tvm::relay::CCacheKey const&) + 3580
  [bt] (0) 1   libtvm.dylib                        0x0000000113c6fec5 std::__1::__function::__func<TVMFuncCreateFromCFunc::$_2, std::__1::allocator<TVMFuncCreateFromCFunc::$_2>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 213
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/backend/_backend.py", line 49, in lower
    f = tvm.driver.lower(sch, inputs, name=func_name)
  File "/Users/pivovaa/workspace/tvm/python/tvm/driver/build_module.py", line 207, in lower
    mod = optimize(mod)
  File "/Users/pivovaa/workspace/tvm/python/tvm/ir/transform.py", line 127, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/Users/pivovaa/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  [bt] (8) 9   libtvm.dylib                        0x000000011337cc60 tvm::tir::StorageFlattener::VisitStmt_(tvm::tir::AttrStmtNode const*) + 944
  [bt] (7) 8   libtvm.dylib                        0x00000001133871b2 tvm::tir::StorageFlattener::HandleBufferBindScope(tvm::tir::AttrStmtNode const*) + 5026
  [bt] (6) 7   libtvm.dylib                        0x0000000112d0585e tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&) + 46
  [bt] (5) 6   libtvm.dylib                        0x0000000112d0dff5 tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&) + 53
  [bt] (4) 5   libtvm.dylib                        0x0000000112d0e2b9 tvm::NodeFunctor<tvm::tir::Stmt (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*) const + 297
  [bt] (3) 4   libtvm.dylib                        0x0000000112d0f808 tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::'lambda0'(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)::__invoke(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*) + 24
  [bt] (2) 3   libtvm.dylib                        0x000000011337cc60 tvm::tir::StorageFlattener::VisitStmt_(tvm::tir::AttrStmtNode const*) + 944
  [bt] (1) 2   libtvm.dylib                        0x0000000113386444 tvm::tir::StorageFlattener::HandleBufferBindScope(tvm::tir::AttrStmtNode const*) + 1588
  [bt] (0) 1   libtvm.dylib                        0x0000000112cdc75f dmlc::LogMessageFatal::~LogMessageFatal() + 111
  File "/Users/pivovaa/workspace/tvm/src/tir/transforms/storage_flatten.cc", line 397
  File "/Users/pivovaa/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/backend/_backend.py", line 57, in lower
    raise RuntimeError(msg)
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/backend/_backend.py", line 49, in lower
    f = tvm.driver.lower(sch, inputs, name=func_name)
  File "/Users/pivovaa/workspace/tvm/python/tvm/driver/build_module.py", line 207, in lower
    mod = optimize(mod)
  File "/Users/pivovaa/workspace/tvm/python/tvm/ir/transform.py", line 127, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/Users/pivovaa/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  [bt] (8) 9   libtvm.dylib                        0x000000011337cc60 tvm::tir::StorageFlattener::VisitStmt_(tvm::tir::AttrStmtNode const*) + 944
  [bt] (7) 8   libtvm.dylib                        0x00000001133871b2 tvm::tir::StorageFlattener::HandleBufferBindScope(tvm::tir::AttrStmtNode const*) + 5026
  [bt] (6) 7   libtvm.dylib                        0x0000000112d0585e tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&) + 46
  [bt] (5) 6   libtvm.dylib                        0x0000000112d0dff5 tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&) + 53
  [bt] (4) 5   libtvm.dylib                        0x0000000112d0e2b9 tvm::NodeFunctor<tvm::tir::Stmt (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*) const + 297
  [bt] (3) 4   libtvm.dylib                        0x0000000112d0f808 tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::'lambda0'(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)::__invoke(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*) + 24
  [bt] (2) 3   libtvm.dylib                        0x000000011337cc60 tvm::tir::StorageFlattener::VisitStmt_(tvm::tir::AttrStmtNode const*) + 944
  [bt] (1) 2   libtvm.dylib                        0x0000000113386444 tvm::tir::StorageFlattener::HandleBufferBindScope(tvm::tir::AttrStmtNode const*) + 1588
  [bt] (0) 1   libtvm.dylib                        0x0000000112cdc75f dmlc::LogMessageFatal::~LogMessageFatal() + 111
  File "/Users/pivovaa/workspace/tvm/src/tir/transforms/storage_flatten.cc", line 397
TVMError: 
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------
  Check failed: it != buf_map_.end() == false: Cannot find buffer of buffer(T_concat, 0x7f8eb0a48be0)

During handling of the above exception, another exception occurred:

TVMError: 
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------
  Check failed: it != buf_map_.end() == false: Cannot find buffer of buffer(T_concat, 0x7f8eb0a48be0)
Error during compile function
-----------------------------
#[version = "0.0.5"]
fn (%p0: Tensor[(1, 4), int64], %p1: Tensor[(1, 4), int64], %p2: Tensor[(4), float32], Primitive=1) -> Tensor[(3, 3), float32] {
  %0 = (%p0, %p1);
  %1 = concatenate(%0) /* ty=Tensor[(2, 4), int64] */;
  scatter_nd(%p2, %1, meta[relay.attrs.ScatterNDAttrs][0]) /* ty=Tensor[(3, 3), float32] */
}
#[metadata]
{
  "root": 1, 
  "nodes": [
    {
      "type_key": ""
    }, 
    {
      "type_key": "Map", 
      "keys": [
        "relay.attrs.ScatterNDAttrs"
      ], 
      "data": [2]
    }, 
    {
      "type_key": "Array", 
      "data": [3]
    }, 
    {
      "type_key": "relay.attrs.ScatterNDAttrs", 
      "attrs": {"out_shape": "4"}
    }, 
    {
      "type_key": "Array", 
      "data": [5, 6]
    }, 
    {
      "type_key": "IntImm", 
      "attrs": {
        "dtype": "int32", 
        "span": "0", 
        "value": "3"
      }
    }, 
    {
      "type_key": "IntImm", 
      "attrs": {
        "dtype": "int32", 
        "span": "0", 
        "value": "3"
      }
    }
  ], 
  "b64ndarrays": [], 
  "attrs": {"tvm_version": "0.8.dev0"}
}

The error is probably because scatter_nd uses schedule_extern (see the cuda schedule in apache/tvm on GitHub at commit a1260cc19342c4db61c6942a11c2b2b2b58f8bad).

To be able to compose with other ops, we need to inline injective ops like this

Thank you Masahiro!

@tkonolige @haan What do you think about this issue?

Should we make scatter_nd kOpaque given both CPU and CUDA schedules are extern? @masahi

yes I think it should be kOpaque

@apivovarov Can you please try with kOpaque fix and see if it works, can you send a PR?

Adding the following at the beginning of the test script helped! Thank you, Animesh!

relay.op.register_pattern("scatter_nd", relay.op.OpPattern.OPAQUE, level=11)

@masahi Should I also add the following line to RELAY_REGISTER_OP("scatter_nd")?

set_attr<TOpIsStateful>("TOpIsStateful", false)

Both scatter and scatter_add have it.

set_support_level is also different among scatter ops.

scatter - 10
scatter_add - 10
scatter_nd - 3

I don’t know what TOpIsStateful is for

PR to set TOpPattern=kOpaque for scatter_nd

Can we add a test to the codebase to forbid fusing with schedule extern? Or to verify that ops with schedule_extern are opaque?