I am new to Tensor Expression (TE). I want to write the compute for PyTorch's EmbeddingBag operation using TE. I can create a schedule for the following compute expression, but the build fails:
from tvm import te

# Symbolic shapes: m rows in the table, n indices, b bags, d embedding width
m = te.var("m")
n = te.var("n")
b = te.var("b")
d = te.var("d")
Data = te.placeholder((m, d), dtype="int8", name="Data")
Indices = te.placeholder((n,), dtype="int32", name="Indices")
Offsets = te.placeholder((b,), dtype="int32", name="Offsets")
Here is the build error:
TVMError: Traceback (most recent call last):
  15: TVMFuncCall
  …
  2: tvm::operator<=(tvm::PrimExpr, tvm::PrimExpr)
  1: tvm::less_equal(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
  0: tvm::BinaryOpMatchTypes(tvm::PrimExpr&, tvm::PrimExpr&, tvm::Span)
  File "/home/tvm/src/tir/op/op.cc", line 163
TVMError: Cannot match type int32 vs handle
Since EmbeddingBag needs the offsets (runtime values) to slice the indices and then apply a "take" operation, can this compute be written in TE/TIR, or do I need to use Hybrid Script for this operation?
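For reference, here is a minimal pure-Python sketch of the semantics in question. It is not TVM code. It assumes sum reduction, that each offset marks the start of a bag, and that the last bag runs to the end of the indices; the function name `embedding_bag_sum` and the sample values are made up for illustration.

```python
def embedding_bag_sum(data, indices, offsets):
    """Sum the rows of `data` selected by each bag's slice of `indices`."""
    num_bags = len(offsets)
    dim = len(data[0]) if data else 0
    out = []
    for bag in range(num_bags):
        start = offsets[bag]
        # Assumption: the last bag extends to the end of `indices`.
        end = offsets[bag + 1] if bag + 1 < num_bags else len(indices)
        row = [0] * dim
        for k in range(start, end):   # slice bounds are only known at runtime
            src = data[indices[k]]    # the "take" step: gather one table row
            for j in range(dim):
                row[j] += src[j]
        out.append(row)
    return out


# Hypothetical sample input: 3 table rows of width 2, 5 indices, 2 bags.
data = [[1, 2], [3, 4], [5, 6]]
indices = [0, 2, 1, 1, 0]
offsets = [0, 2]
result = embedding_bag_sum(data, indices, offsets)
print(result)
```

The inner loop bounds `start` and `end` come from the contents of `offsets`, not from the tensor shapes, which is the data-dependent part the question is about.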