Check failed: (it != rmap->end()) when using relay.build in model with TFLite_Detection_PostProcess (with minimal example)

I am importing a TensorFlow Lite model into TVM using the following script:

import tvm
from tvm import relay
import tensorflow as tf

tflite_file = "int8.tflite"
tflite_model_buf = open(tflite_file, "rb").read()
try:
    import tflite
    tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
    import tflite.Model
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
    
version = tflite_model.Version()
print("Model Version: " + str(version))

input_tensor = "tflite_images:0"
input_shape = (1, 512, 512, 3)
input_dtype = "int8"

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=tflite_file)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()

print("Importing TFLite model...")
mod, params = relay.frontend.from_tflite(
    tflite_model, shape_dict={input_tensor: input_shape}, dtype_dict={input_tensor: input_dtype}
)

RUNTIME = tvm.relay.backend.Runtime("crt", {"system-lib": True})
TARGET = tvm.target.target.Target({"kind": "c"})
EXECUTOR = tvm.relay.backend.Executor("aot", options={"interface-api": "c", "unpacked-api": True})
# EXECUTOR = tvm.relay.backend.Executor("graph", {"link-params": True})

print("Building model with TVM...")
with tvm.transform.PassContext(
    opt_level=3, config={"tir.disable_vectorize": True}, disabled_pass=["AlterOpLayout"]
):
    module = relay.build(mod, executor=EXECUTOR, runtime=RUNTIME, target=TARGET, params=params)

No matter which executor I use (graph or AOT), I get the error:

Traceback (most recent call last):
  File "main.py", line 70, in <module>
    module = relay.build(mod, executor=EXECUTOR, runtime=RUNTIME, target=TARGET, params=params)
  File "/local_disk/local_sw/tvm/python/tvm/relay/build_module.py", line 416, in build
    graph_json, runtime_mod, params = bld_mod.build(
  File "/local_disk/local_sw/tvm/python/tvm/relay/build_module.py", line 154, in build
    self._build(mod, raw_targets, executor, runtime, workspace_memory_pools, mod_name)
  File "/local_disk/local_sw/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  40: TVMFuncCall
  39: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  38: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  37: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::AOTExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  36: tvm::relay::backend::AOTExecutorCodegen::Codegen(tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
  35: tvm::transform::Pass::operator()(tvm::IRModule) const
  34: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  33: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  32: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  31: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  30: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay3tec11LowerTEPassERKNS0_6StringESt8functionIFvNS_8BaseFuncEEENS_13VirtualDeviceEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SN_SR_
  29: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String const&, std::function<void (tvm::BaseFunc)>, tvm::VirtualDevice)
  28: tvm::transform::Pass::operator()(tvm::IRModule) const
  27: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  26: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  25: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_5relay8FunctionES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_3tec15LowerTensorExprERKNS0_6StringENSD_10TECompilerESt8functionIFvNS_8BaseFuncEEENS_13VirtualDeviceEEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SP_ST_
  24: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  23: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  22: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  21: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
  20: _ZN3tvm5relay9transform22DeviceAwareExprMutator21DeviceAwareVisit
  19: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  18: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  17: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  16: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::LetNode const*)
  15: tvm::relay::tec::LowerTensorExprMutator::PreVisitLetBinding_(tvm::relay::Var const&, tvm::RelayExpr const&)
  14: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  13: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  12: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
  11: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  10: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  9: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
  8: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
  7: tvm::relay::tec::LowerTensorExprMutator::MakeLoweredCall(tvm::relay::Function, tvm::runtime::Array<tvm::RelayExpr, void>, tvm::Span, tvm::Target)
  6: tvm::relay::tec::TECompilerImpl::Lower(tvm::relay::tec::CCacheKey const&, tvm::runtime::String)
  5: tvm::relay::tec::TECompilerImpl::LowerInternal(tvm::relay::tec::CCacheKey const&, std::function<tvm::runtime::String (tvm::runtime::String)>)
  4: tvm::LowerSchedule(tvm::te::Schedule, tvm::runtime::Array<tvm::te::Tensor, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, bool)
  3: tvm::LowerSchedule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, bool)
  2: tvm::ScheduleToModule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&)
  1: tvm::te::InferBound(tvm::te::Schedule const&)
  0: tvm::te::InferRootBound(tvm::te::Stage const&, tvm::te::GraphContext const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > >*)
  File "/local_disk/local_sw/tvm/src/te/schedule/bound.cc", line 144
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (it != rmap->end()) is false: 

What could be the cause of this behavior?

UPDATE:

The problem occurs when processing the argsort_nms_cpu stage, which is inside non_max_suppression.

To give more context: my TFLite model has a TFLite_Detection_PostProcess node, and I can see that the TVM frontend correctly detects this operator and calls the function convert_detection_postprocess().

This function inserts the operator _op.vision.non_max_suppression, which I noticed is the one on which the LowerSchedule method is failing.
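
For reference, here is a minimal sketch of how one can check which operator ends up in the failing fused function, assuming the mod, params, and TARGET from the script above. relay.optimize runs the optimization pipeline without lowering, so the resulting module can still be printed and inspected:

# A minimal sketch, reusing mod/params/TARGET from the import script above.
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
    opt_mod, _ = relay.optimize(mod, target=TARGET, params=params)
# Look for the fused function that calls vision.non_max_suppression.
print(opt_mod)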

UPDATE:

I found that this error appears when building with either executor, but ONLY when the executor option link-params is True.
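
To isolate the flag, I used a sketch like the following (reusing mod, params, RUNTIME, and TARGET from the script above; only link-params changes between the two iterations):

# A minimal sketch, assuming the mod/params/RUNTIME/TARGET defined earlier.
for link in (False, True):
    executor = tvm.relay.backend.Executor("graph", {"link-params": link})
    try:
        with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
            relay.build(mod, executor=executor, runtime=RUNTIME, target=TARGET, params=params)
        print(f"link-params={link}: build OK")
    except tvm.TVMError:
        print(f"link-params={link}: Check failed: (it != rmap->end())")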

Building with the graph executor

When I build with the graph executor and link-params set to False, relay.build finishes correctly, but then I get an error when compiling the generated C code:

'__tvm_module_ctx' undeclared

I noticed that the declaration of this symbol is generated in src/target/source/codegen_c_host.cc, but only when the executor is AOT. So the real question is: why is __tvm_module_ctx being referenced at all?

The line in the C code where this is used is the following:

if (TVMBackendGetFuncFromEnv(__tvm_module_ctx, "tvm.contrib.sort.argsort_nms", &tvm_contrib_sort_argsort_nms_packed) != 0)

This call sits inside the function generated for the non-max-suppression operator (tvmgen_default_fused_vision_non_max_suppression).

Looking at the generated code, I realized that this line is emitted by PrintGetFuncFromBackend in src/target/source/codegen_c_host.cc whenever the op is builtin::tvm_call_packed_lowered(). The question now is: why is this operator called this way? What is special about it?

Looking at python/tvm/topi/sort.py, I found it: there is a tvm.tir.call_packed in the argsort_nms stage, and that is what generates this code.
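
For context, this is roughly what that stage looks like (a self-contained paraphrase of the upstream code I was reading; argument details may vary between commits). The output tensor is produced by te.extern, whose body is a packed call into the external tvm.contrib.sort.argsort_nms function, and that packed call is exactly what the C backend lowers to a TVMBackendGetFuncFromEnv lookup:

import tvm
from tvm import te

# Paraphrased sketch of the argsort_nms stage in python/tvm/topi/sort.py;
# the placeholder shapes are made up here just to keep the snippet runnable.
data = te.placeholder((1, 100, 6), name="data", dtype="float32")
valid_count = te.placeholder((1,), name="valid_count", dtype="int32")
axis, is_ascend = 1, False

out = te.extern(
    data.shape,
    [data, valid_count],
    lambda ins, outs: tvm.tir.call_packed(
        # This packed call is what codegen_c_host turns into
        # TVMBackendGetFuncFromEnv(__tvm_module_ctx, ...).
        "tvm.contrib.sort.argsort_nms",
        ins[0], ins[1], outs[0], axis, is_ascend,
    ),
    dtype="int32",
    name="argsort_nms_cpu",
    tag="argsort_nms_cpu",
)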

Is there any workaround for this?

Minimal example to reproduce

I tested this on the main branch and was able to reproduce the same error with a minimal example:

import tvm
from tvm import relay

x0 = relay.var("x0", relay.ty.TensorType((1, relay.Any(), 6), "float32"))
x1 = relay.var("x1", relay.ty.TensorType((1,), "int32"))
x2 = relay.var("x2", relay.ty.TensorType((1, relay.Any()), "int32"))
x3 = relay.var("x3", relay.ty.TensorType((), "int32"))
z = relay.vision.non_max_suppression(
    x0,
    x1,
    x2,
    x3,
    iou_threshold=0.5,
    force_suppress=True,
    top_k=2,
    return_indices=True,
    invalid_to_bottom=False,
)
z = z.astuple()
func = relay.Function([x0, x1, x2, x3], z)
mod = tvm.IRModule()
mod["main"] = func

print(mod["main"])
RUNTIME = tvm.relay.backend.Runtime("crt", {"system-lib": True})
TARGET = tvm.target.target.Target({"kind": "c"})
EXECUTOR = tvm.relay.backend.Executor("aot", options={"interface-api": "packed", "unpacked-api": False, "link-params": True})
mod = relay.build(mod, executor=EXECUTOR, target=TARGET, runtime=RUNTIME)

print(mod)

Console output (first the print of mod["main"], then the error):

fn (%x0: Tensor[(1, ?, 6), float32], %x1: Tensor[(1), int32], %x2: Tensor[(1, ?), int32], %x3: int32) {
  vision.non_max_suppression(%x0, %x1, %x2, %x3, 0.5f, force_suppress=True, top_k=2)
}
[18:10:19] /local_disk/local_sw/tvm_main/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = hybrid(hybrid_nms, 0x349b6a0)
Traceback (most recent call last):
  File "test.py", line 32, in <module>
    mod = relay.build(mod, executor=EXECUTOR, target=TARGET,runtime=RUNTIME)
  File "/local_disk/local_sw/tvm_main/tvm/python/tvm/relay/build_module.py", line 416, in build
    graph_json, runtime_mod, params = bld_mod.build(
  File "/local_disk/local_sw/tvm_main/tvm/python/tvm/relay/build_module.py", line 154, in build
    self._build(mod, raw_targets, executor, runtime, workspace_memory_pools, mod_name)
  File "/local_disk/local_sw/tvm_main/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  40: TVMFuncCall
  39: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  38: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  37: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::AOTExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  36: tvm::relay::backend::AOTExecutorCodegen::Codegen(tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
  35: tvm::transform::Pass::operator()(tvm::IRModule) const
  34: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  33: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  32: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  31: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  30: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay3tec11LowerTEPassENS0_6StringESt8functionIFvNS_8BaseFuncEEENS_17CompilationConfigEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SL_SP_
  29: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String const&, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)
  28: tvm::transform::Pass::operator()(tvm::IRModule) const
  27: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  26: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  25: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_5relay8FunctionES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_3tec15LowerTensorExprERKNS0_6StringENSD_10TECompilerESt8functionIFvNS_8BaseFuncEEENS_17CompilationConfigEEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SP_ST_
  24: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  23: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  22: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  21: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
  20: _ZN3tvm5relay9transform22DeviceAwareExprMutator21DeviceAwareVisit
  19: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  18: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  17: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  16: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::LetNode const*)
  15: tvm::relay::tec::LowerTensorExprMutator::PreVisitLetBinding_(tvm::relay::Var const&, tvm::RelayExpr const&)
  14: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  13: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  12: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
  11: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  10: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  9: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
  8: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
  7: tvm::relay::tec::LowerTensorExprMutator::MakeLoweredCall(tvm::relay::Function, tvm::runtime::Array<tvm::RelayExpr, void>, tvm::Span, tvm::Target)
  6: tvm::relay::tec::TECompilerImpl::Lower(tvm::relay::tec::CCacheKey const&, tvm::runtime::String)
  5: tvm::relay::tec::TECompilerImpl::LowerInternal(tvm::relay::tec::CCacheKey const&, std::function<tvm::runtime::String (tvm::runtime::String)>)
  4: tvm::LowerSchedule(tvm::te::Schedule, tvm::runtime::Array<tvm::te::Tensor, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, bool)
  3: tvm::LowerSchedule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, bool)
  2: tvm::ScheduleToModule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&)
  1: tvm::te::InferBound(tvm::te::Schedule const&)
  0: tvm::te::InferRootBound(tvm::te::Stage const&, tvm::te::GraphContext const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > >*)
  File "/local_disk/local_sw/tvm_main/tvm/src/te/schedule/bound.cc", line 144
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (it != rmap->end()) is false: 

After some more analysis, I was able to build with the AOT executor using the following code (the same example, but with static shapes):

import tvm
from tvm import relay

# An example using non_max_suppression

x0 = relay.var("x0", relay.ty.TensorType((1, 100, 6), "float32"))
x1 = relay.var("x1", relay.ty.TensorType((1,), "int32"))
x2 = relay.var("x2", relay.ty.TensorType((1, 100), "int32"))
x3 = relay.var("x3", relay.ty.TensorType((), "int32"))
z = relay.vision.non_max_suppression(
    x0,
    x1,
    x2,
    x3,
    iou_threshold=0.5,
    force_suppress=True,
    top_k=2,
    return_indices=True,
    invalid_to_bottom=False,
)
z = z.astuple()
func = relay.Function([x0, x1, x2, x3], z)
mod = tvm.IRModule()
mod["main"] = func

print(mod["main"])
RUNTIME = tvm.relay.backend.Runtime("crt", {"system-lib": True})
TARGET = tvm.target.target.Target({"kind": "c"})
EXECUTOR = tvm.relay.backend.Executor("aot", options={"interface-api": "c", "unpacked-api": True, "link-params": True})
mod = relay.build(mod, executor=EXECUTOR, target=TARGET, runtime=RUNTIME)

print(mod)

However, I was only able to build it after modifying the GetNewArguments method in src/relay/transforms/fuse_ops.cc.

Previous code:

...
if (!link_params_ || new_arg.as<ConstantNode>() == nullptr) {
    Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
    new_args.push_back(param);
} else {
    new_args.push_back(new_arg);
}
...

Modified code:

...
if (!link_params_ || new_arg.as<ConstantNode>() == nullptr) {
    Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
    new_args.push_back(param);
} else {
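    // (My modification) Allocate a parameter for the constant argument as
    // well, instead of inlining the constant into the fused function.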
    Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
    new_args.push_back(param);
}
...

I would like to understand why there is this distinction in the GetNewArguments method that pushes different args into new_args. I am not sure whether I found a bug, or whether there is some earlier error (perhaps in the definition of non_max_suppression).
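
To make the distinction concrete, here is a minimal sketch of how one can observe it in isolation. It assumes the pass-context key relay.FuseOps.link_params, which is what GetNewArguments consults in the version I was reading; treat the key name as an assumption if your commit differs:

import numpy as np
import tvm
from tvm import relay

# A minimal sketch: run FuseOps with and without link_params and compare how a
# constant argument is handled in the fused function.
x = relay.var("x", shape=(4,), dtype="float32")
const = relay.const(np.ones((4,), dtype="float32"))
mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, const)))

seq = tvm.transform.Sequential(
    [relay.transform.InferType(), relay.transform.FuseOps(fuse_opt_level=2)]
)
for link in (False, True):
    # "relay.FuseOps.link_params" is the key I found in fuse_ops.cc (assumed).
    with tvm.transform.PassContext(opt_level=3, config={"relay.FuseOps.link_params": link}):
        print(f"link_params={link}")
        # With False the constant becomes a parameter of the fused function;
        # with True it stays inlined as meta[relay.Constant][0].
        print(seq(mod))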

I created an issue on GitHub to track this bug.