I am trying to write a TVM function that will index an array using values of another array, something like the example below, but I get a TVMError
that the index should be an int. Is there a way to do this?
import tvm
n = tvm.var("n")
m = tvm.var("m")
A = tvm.placeholder((n), name='A')
B = tvm.placeholder((m), name='B')
C = tvm.compute((m,), lambda i: A[B[i]], name="C") # <<== indexing array by values of another array
s = tvm.create_schedule(C.op)
f = tvm.build(s, [A, B], None)
TVMError Traceback (most recent call last)
<ipython-input-7-d9438b269026> in <module>
5 m = tvm.var("m")
6 B = tvm.placeholder((m), name='B')
----> 7 C = tvm.compute((m,), lambda i: A[B[i]], name="C")
8 s = tvm.create_schedule(C.op)
9 f = tvm.build(s, [A, B], None)
/usr/tvm/python/tvm/api.py in compute(shape, fcompute, name, tag, attrs)
329 if not isinstance(body, (list, tuple)):
330 body = [body]
--> 331 body = convert(body)
332 op_node = _api_internal._ComputeOp(
333 name, tag, attrs, dim_var, body)
/usr/tvm/python/tvm/api.py in convert(value)
137 return _convert_tvm_func(value)
138
--> 139 return _convert_to_node(value)
140
141
/usr/tvm/python/tvm/_ffi/node_generic.py in convert_to_node(value)
77 return _api_internal._str(value)
78 if isinstance(value, (list, tuple)):
---> 79 value = [convert_to_node(x) for x in value]
80 return _api_internal._Array(*value)
81 if isinstance(value, dict):
/usr/tvm/python/tvm/_ffi/node_generic.py in <listcomp>(.0)
77 return _api_internal._str(value)
78 if isinstance(value, (list, tuple)):
---> 79 value = [convert_to_node(x) for x in value]
80 return _api_internal._Array(*value)
81 if isinstance(value, dict):
/usr/tvm/python/tvm/_ffi/node_generic.py in convert_to_node(value)
89 return _api_internal._Map(*vlist)
90 if isinstance(value, NodeGeneric):
---> 91 return value.asnode()
92 if value is None:
93 return None
/usr/tvm/python/tvm/tensor.py in asnode(self)
40 def asnode(self):
41 """Convert slice to node."""
---> 42 return self.tensor(*self.indices)
43
44 @property
/usr/tvm/python/tvm/tensor.py in __call__(self, *indices)
75 return _make.Call(self.dtype, self.op.name,
76 args, _expr.Call.Halide,
---> 77 self.op, self.value_index)
78
79 def __getitem__(self, indices):
/usr/tvm/python/tvm/_ffi/_ctypes/function.py in __call__(self, *args)
205 self.handle, values, tcodes, ctypes.c_int(num_args),
206 ctypes.byref(ret_val), ctypes.byref(ret_tcode)) != 0:
--> 207 raise get_last_ffi_error()
208 _ = temp_args
209 _ = args
TVMError: Traceback (most recent call last):
[bt] (3) /usr/tvm/build/libtvm.so(TVMFuncCall+0x61) [0x7f905993fe21]
[bt] (2) /usr/tvm/build/libtvm.so(+0x452803) [0x7f905911e803]
[bt] (1) /usr/tvm/build/libtvm.so(tvm::ir::Call::make(tvm::DataType, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::Expr, void>, tvm::ir::Call::CallType, tvm::ir::FunctionRef, int)+0x3ca) [0x7f90592d51da]
[bt] (0) /usr/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7f90590f9d12]
File "/usr/tvm/src/lang/ir.cc", line 207
TVMError: Check failed: args[i].type().is_int():