TOPI reshape operator directly cast IntImmNode
to int32
, but in my application I need int64
.
inline Tensor reshape(const Tensor& x, Array<PrimExpr> newshape, std::string name = "T_reshape",
std::string tag = kInjective) {
auto x_shape = x->shape;
Array<PrimExpr> target_shape;
for (const auto& ele : newshape) {
if (ele.as<IntImmNode>()) {
target_shape.push_back(cast(DataType::Int(32), ele));
} else {
target_shape.push_back(ele);
}
}
if (is_empty_shape(target_shape)) {
return compute(
target_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
} else {
return compute(
target_shape,
[&](const Array<Var>& indices) {
return x(UnravelIndex(
RavelIndex(Array<PrimExpr>{indices.begin(), indices.end()}, target_shape), x_shape));
},
name, tag);
}
}
Can we make it more flexible so that we can set the type by passing IntImm
in python?