Hi all,
I am trying to construct Relay expressions using the Python API. I am using the following code
‘’’
import tvm
from tvm import relay
from tvm.relay import var, Var, Function, Let, TensorType, Module
from tvm.relay.prelude import Prelude
from network import Network
class GatherNDTest():
def __init__(self, name="f", **kwargs):
self.mod = Module()
self.p = Prelude(self.mod)
self.inputs = []
self.f = relay.GlobalVar(name)
self.mod[self.f] = relay.Function(self.inputs, self.build_impl(**kwargs), None)
def input(self, i):
self.inputs.append(i)
return i
def build_impl(self, vocab_size, embedding_size):
# Input declarations
embeddings_type = TensorType(shape = (vocab_size, embedding_size), dtype = "float32")
embeddings = self.input(Var("embeddings", embeddings_type))
word_index_type = TensorType(shape = (), dtype = "int32")
word_index = self.input(Var("word_index", word_index_type))
index_expr = tvm.relay.gather_nd(embeddings, word_index)
return index_expr
mod = GatherNDTest(name = "gather_test", vocab_size = 200, embedding_size = 400).mod
print(mod)
‘’’
The backtrace from GDB is here (https://pastebin.com/cvbtn1tm ). I am using commit e3eff20d8a91c2b986548f780493a8e30b0c2b7a on master.
In any case, I cannot figure out how to pass arguments to gather_nd, so I can index the embeddings tensor to get a single word embedding. Thanks!
hgt312
December 4, 2019, 4:29am
2
Try to change the shape of word_index
?
Do you mean making it a tensor instead of the scalar it is now? I tried changing the shape to (1, ) instead of the empty tuple it is now and I get the following error.
‘’’
Traceback (most recent call last):
File "test.py", line 31, in <module>
mod = GatherNDTest(name = "gather_test", vocab_size = 200, embedding_size = 400).mod
File "test.py", line 14, in __init__
self.mod[self.f] = relay.Function(self.inputs, self.build_impl(**kwargs), None)
File "/home/ppf/data/ppf/projects/projects/rnn_compilers/incubator-tvm/python/tvm/relay/module.py", line 85, in __setitem__
return self._add(var, val)
File "/home/ppf/data/ppf/projects/projects/rnn_compilers/incubator-tvm/python/tvm/relay/module.py", line 94, in _add
_module.Module_Add(self, var, val, update)
File "/home/ppf/data/ppf/projects/projects/rnn_compilers/incubator-tvm/python/tvm/_ffi/_ctypes/function.py", line 207, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (8) /home/ppf/data/ppf/projects/projects/rnn_compilers/incubator-tvm/build/libtvm.so(tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::relay::Expr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::relay::Expr const&)>*) const+0x151) [0x7fd7121430c3]
[bt] (7) /home/ppf/data/ppf/projects/projects/rnn_compilers/incubator-tvm/build/libtvm.so(tvm::relay::ExprFunctor<void (tvm::relay::Expr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::relay::Expr const&)>*)#5}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::relay::Expr const&)>*)+0x25) [0x7fd71214279e]
[bt] (6) /home/ppf/data/ppf/projects/projects/rnn_compilers/incubator-tvm/build/libtvm.so(tvm::relay::ExprFunctor<void (tvm::relay::Expr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::relay::Expr const&)>*)#5}::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::relay::Expr const&)>*) const+0x3e) [0x7fd712142772]
[bt] (5) /home/ppf/data/ppf/projects/projects/rnn_compilers/incubator-tvm/build/libtvm.so(tvm::relay::WellFormedChecker::VisitExpr_(tvm::relay::FunctionNode const*)+0xd5) [0x7fd7122a9b29]
[bt] (4) /home/ppf/data/ppf/projects/projects/rnn_compilers/incubator-tvm/build/libtvm.so(tvm::relay::WellFormedChecker::CheckWellFormed(tvm::relay::Expr const&)+0x23) [0x7fd7122a9d05]
[bt] (3) /home/ppf/data/ppf/projects/projects/rnn_compilers/incubator-tvm/build/libtvm.so(tvm::relay::WellFormedChecker::VisitExpr(tvm::relay::Expr const&)+0x4f) [0x7fd7122a9cdf]
[bt] (2) /home/ppf/data/ppf/projects/projects/rnn_compilers/incubator-tvm/build/libtvm.so(tvm::relay::ExprVisitor::VisitExpr(tvm::relay::Expr const&)+0xa3) [0x7fd7123a3f53]
[bt] (1) /home/ppf/data/ppf/projects/projects/rnn_compilers/incubator-tvm/build/libtvm.so(tvm::relay::ExprFunctor<void (tvm::relay::Expr const&)>::VisitExpr(tvm::relay::Expr const&)+0x98) [0x7fd712142108]
[bt] (0) /home/ppf/data/ppf/projects/projects/rnn_compilers/incubator-tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x4a) [0x7fd711a5527c]
File "/home/ppf/Documents/projects/rnn_compilers/incubator-tvm/include/tvm/relay/expr_functor.h", line 92
TVMError: Check failed: n.defined():
‘’’