Can't pickle IRModule with gather_nd operator

Hi, I've run into a strange problem. For my work, I need to save an IRModule and its parameters, which are constructed from a custom DNN model.

Problem: once the IRModule contains gather_nd, the pickle file can't be loaded successfully.

Here is my code to save the IRModule:

```python
from tvm import relay
import tvm
import pickle
from tvm.relay import expr as _expr
from tvm.relay import op as _op
from tvm.relay.frontend.common import infer_shape
import numpy as np

def gather_nd(data, indices):
    # `indices` arrives with each coordinate tuple on its LAST axis, while
    # relay's gather_nd expects it on the FIRST axis, so move it to the front.
    indices_dims = len(infer_shape(indices))
    indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
    out = _op.gather_nd(data, indices, batch_dims=0)
    return out

ipt0 = relay.var('x', shape=[2, 15, 4], dtype='float32')
index = np.array([[0, 5], [1, 2], [0, 7]]).astype('int64')
ipt1 = _expr.const(index)

out = gather_nd(ipt0, ipt1)

# Wrap the expression in a function over its free variables and type-check it.
args = relay.analysis.free_vars(out)
func = relay.Function(args, out)
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)

with open('relay.pkl', 'wb') as f:
    pickle.dump(mod, f)
```
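For readers wondering about the transpose in the helper: relay's gather_nd takes `indices` of shape `(M, Y_0, ..., Y_{K-1})`, with the coordinate tuple along the first axis, and returns shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{n-1})`. Below is a NumPy-only sketch of that behavior, assuming `batch_dims=0` (the value used above) and ignoring `index_rank`. The function names are hypothetical and this is not TVM code; it only models the semantics.

```python
import numpy as np

def relay_style_gather_nd(data, indices):
    """NumPy model of relay gather_nd with batch_dims=0.

    indices has shape (M, Y_0, ..., Y_{K-1}); indices[i] indexes axis i of data.
    """
    m = indices.shape[0]
    assert m <= data.ndim, "coordinate tuple is longer than the data rank"
    # Advanced indexing with M index arrays yields
    # shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{n-1}).
    return data[tuple(indices)]

def last_axis_gather_nd(data, indices):
    """Model of the helper in the post: coordinate tuple on the LAST axis."""
    ndim = indices.ndim
    # Same permutation as axes=[-1] + list(range(ndim - 1)) in the post.
    moved = np.transpose(indices, axes=[ndim - 1] + list(range(ndim - 1)))
    return relay_style_gather_nd(data, moved)

data = np.arange(2 * 15 * 4, dtype="float32").reshape(2, 15, 4)
index = np.array([[0, 5], [1, 2], [0, 7]], dtype="int64")
out = last_axis_gather_nd(data, index)
print(out.shape)  # (3, 4)
```

With the post's `[2, 15, 4]` input, the three output rows are `data[0, 5]`, `data[1, 2]`, and `data[0, 7]`.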

The error happens while loading the pickle file:

```python
import pickle
with open('relay.pkl', 'rb') as f:
    mod = pickle.load(f)
```

Error message:

```
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "/ssd5/jiangjiajun/tvm/python/tvm/runtime/object.py", line 91, in __setstate__
    self.__init_handle_by_constructor__(_ffi_node_api.LoadJSON, handle)
  File "/ssd5/jiangjiajun/tvm/python/tvm/_ffi/_ctypes/object.py", line 136, in __init_handle_by_constructor__
    handle = __init_by_constructor__(fconstructor, args)
  File "/ssd5/jiangjiajun/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 260, in __init_handle_by_constructor__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  5: TVMFuncCall
  4: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::runtime::ObjectRef (std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)>::AssignTypedLambda<tvm::runtime::ObjectRef (*)(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)>(tvm::runtime::ObjectRef (*)(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  3: tvm::LoadJSON(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)
  2: tvm::ReflectionVTable::VisitAttrs(tvm::runtime::Object*, tvm::AttrVisitor*) const
  1: tvm::FieldDependencyFinder::Visit(char const*, tvm::runtime::ObjectRef*)
  0: void tvm::FieldDependencyFinder::ParseValue<unsigned long>(char const*, unsigned long*) const
  File "/ssd5/jiangjiajun/tvm/src/node/serialization.cc", line 291
JSONReader: cannot find field axis
```

I've also tried astext and fromtext and got the same error message.

It looks like the JSON loader uses the GatherAttrs visitor to deserialize the gather_nd attributes, which have no axis field, but I'm not familiar with the attribute visitor, so I can't dig into more detail.
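To make that hypothesis concrete, here is a toy, pure-Python model of schema-driven deserialization. It is not TVM's serializer (which, per the traceback, lives in src/node/serialization.cc); `save_attrs`, `load_attrs`, and the `ToyGatherNDAttrs` key are all hypothetical. It only illustrates that when the reader visits a field list (here `axis`) that differs from what the writer stored (`batch_dims`, `index_rank`), loading fails with a message shaped like the one in the traceback.

```python
import json

def save_attrs(type_key, fields):
    """Serialize a toy attrs node as JSON: its type key plus field values."""
    return json.dumps({"type_key": type_key, "attrs": fields})

def load_attrs(text, schema):
    """Rebuild a toy attrs node by visiting the fields that `schema`
    registers for its type key; a missing field is a hard error."""
    node = json.loads(text)
    values = {}
    for name in schema[node["type_key"]]:
        if name not in node["attrs"]:
            raise ValueError(f"JSONReader: cannot find field {name}")
        values[name] = node["attrs"][name]
    return values

saved = save_attrs("ToyGatherNDAttrs", {"batch_dims": 0, "index_rank": None})

# Reader schema matches the writer: loads fine.
good = load_attrs(saved, {"ToyGatherNDAttrs": ["batch_dims", "index_rank"]})

# Reader schema registered with the wrong field list: fails.
try:
    load_attrs(saved, {"ToyGatherNDAttrs": ["axis"]})
except ValueError as err:
    print(err)  # JSONReader: cannot find field axis
```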

cc @tqchen

@jiangjiajun Thanks, it is fixed in apache/tvm PR #8269, "Fix GatherND attribute registration" by masahi.

I recently hit the same error and reported it in apache/tvm PR #8143, "[AutoTVM][AutoScheduler] Add workaround to alter op layout bug in task extraction." by jwfromm. There, the error happened during a deepcopy of a model that has gather_nd. At the time I had no idea what the issue was, but your minimal repro helped me find this bug. Thanks!
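Why would a deepcopy hit the same bug as pickle? For a class without its own `__deepcopy__`, Python's `copy.deepcopy` falls back to the pickle protocol (`__reduce_ex__`), which calls `__getstate__` and then `__setstate__` on the new object. The traceback shows TVM's `Object.__setstate__` rebuilding the object through `LoadJSON`, so, assuming IRModule defines no custom `__deepcopy__`, a deepcopy would go through the same JSON loader. The toy class below (`JsonBacked` is hypothetical, loosely modeled on that behavior) just records which hooks deepcopy invokes.

```python
import copy
import json

class JsonBacked:
    """Toy object whose pickle state is a JSON string."""
    calls = []  # records which pickle hooks were invoked

    def __init__(self, payload):
        self.payload = payload

    def __getstate__(self):
        JsonBacked.calls.append("getstate")
        return json.dumps(self.payload)

    def __setstate__(self, state):
        JsonBacked.calls.append("setstate")
        self.payload = json.loads(state)

original = JsonBacked({"op": "gather_nd", "batch_dims": 0})
clone = copy.deepcopy(original)
print(JsonBacked.calls)  # ['getstate', 'setstate']
```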


Thanks! The problem is solved with the latest code.
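To confirm that a given build includes the fix, or to catch a similar regression at save time rather than when someone later opens the file, one option is to round-trip the object right after dumping it. A minimal sketch; `check_pickle_roundtrip` is a hypothetical helper, and the TVM call in the comment assumes `tvm.ir.structural_equal` is available in your version.

```python
import pickle

def check_pickle_roundtrip(obj, equal=lambda a, b: a == b):
    """Pickle `obj` to bytes and load it back immediately, so a broken
    loader fails here instead of later. Returns the reloaded object."""
    data = pickle.dumps(obj)
    restored = pickle.loads(data)  # raises here if deserialization is broken
    if not equal(obj, restored):
        raise AssertionError("object changed across pickle round-trip")
    return restored

# With TVM installed, pass a structural comparison instead of ==, e.g.:
#   check_pickle_roundtrip(mod, tvm.ir.structural_equal)
restored = check_pickle_roundtrip({"shape": [2, 15, 4], "dtype": "float32"})
```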