Hi, I’ve run into a strange problem. For my work, I need to save an IRModule and its parameters, which are constructed from a custom DNN model.
Problem: once the IRModule contains `gather_nd`, the pickle file can’t be loaded successfully.
Here is my code to save the IRModule:
from tvm import relay
import tvm
import pickle
from tvm.relay import expr as _expr
from tvm.relay import op as _op
from tvm.relay.frontend.common import infer_shape
import numpy as np
def gather_nd(data, indices):
    """Gather from ``data`` with an index tensor whose coordinates are on its last axis.

    The last axis of ``indices`` is rotated to position 0 (all other axes keep
    their relative order) before delegating to relay's ``gather_nd``.
    NOTE(review): this assumes relay's gather_nd expects the coordinate axis
    first, unlike e.g. TensorFlow's last-axis convention — confirm against the
    installed TVM version.

    Args:
        data: relay expression to gather from.
        indices: relay expression holding the index tuples; its static shape
            must be inferable by ``infer_shape``.

    Returns:
        The relay expression produced by ``_op.gather_nd``.
    """
    rank = len(infer_shape(indices))

    # Permutation [-1, 0, 1, ..., rank-2]: last axis to the front.
    perm = [-1]
    perm.extend(range(rank - 1))
    coords_first = _op.transpose(indices, axes=perm)

    # Third positional argument 0 is presumably batch_dims — verify against
    # the relay.op.gather_nd signature in use.
    return _op.gather_nd(data, coords_first, 0)
# Build a small relay function: gather from a float32 input of shape
# [2, 15, 4] using a constant int64 index array of shape (3, 2).
ipt0 = relay.var('x', shape=[2, 15, 4], dtype='float32')
index = np.array([[0, 5], [1, 2], [0, 7]], dtype='int64')
ipt1 = _expr.const(index)
out = gather_nd(ipt0, ipt1)

# Wrap the expression in a function over its free variables, place it in an
# IRModule, and run type inference on that module.
args = relay.analysis.free_vars(out)
func = relay.Function(args, out)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))

# Serialize the type-checked module with pickle.
with open('relay.pkl', mode='wb') as f:
    pickle.dump(mod, f)
The error happens while loading the pickle file:
import pickle

# Deserialize the module written above. Per the traceback, unpickling goes
# through tvm.runtime.Object.__setstate__, which rebuilds the object via
# LoadJSON. NOTE: pickle.load must only be used on trusted files.
with open('relay.pkl', mode='rb') as f:
    mod = pickle.load(f)
Error message:
Traceback (most recent call last):
File "<stdin>", line 2, in <module>
File "/ssd5/jiangjiajun/tvm/python/tvm/runtime/object.py", line 91, in __setstate__
self.__init_handle_by_constructor__(_ffi_node_api.LoadJSON, handle)
File "/ssd5/jiangjiajun/tvm/python/tvm/_ffi/_ctypes/object.py", line 136, in __init_handle_by_constructor__
handle = __init_by_constructor__(fconstructor, args)
File "/ssd5/jiangjiajun/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 260, in __init_handle_by_constructor__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
5: TVMFuncCall
4: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::runtime::ObjectRef (std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)>::AssignTypedLambda<tvm::runtime::ObjectRef (*)(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)>(tvm::runtime::ObjectRef (*)(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
3: tvm::LoadJSON(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)
2: tvm::ReflectionVTable::VisitAttrs(tvm::runtime::Object*, tvm::AttrVisitor*) const
1: tvm::FieldDependencyFinder::Visit(char const*, tvm::runtime::ObjectRef*)
0: void tvm::FieldDependencyFinder::ParseValue<unsigned long>(char const*, unsigned long*) const
File "/ssd5/jiangjiajun/tvm/src/node/serialization.cc", line 291
JSONReader: cannot find field axis