The following code should reproduce my bug:
```rust
#[test]
fn relay_bug() {
    let shape = vec![1, 2, 3];
    let shape_prim_expr_tvm_array: Array<PrimExpr> = tvm::runtime::array::Array::from_vec(
        shape
            .iter()
            .map(|i| IntImm::from(i32::try_from(*i).unwrap()).upcast())
            .collect(),
    )
    .unwrap();

    // Body: a constant f32 tensor of zeros with the given shape.
    let body: Expr = Constant::new(
        NDArray::from_rust_ndarray(
            &ndarray12::ArrayD::<f32>::zeros(shape),
            Device::cpu(0),
            DataType::float32(),
        )
        .unwrap(),
        Span::null(),
    )
    .upcast();

    // A function with no parameters that returns the constant.
    let inner_func: Expr = Function::new(
        tvm::runtime::array::Array::from_vec(vec![]).unwrap(), // no params for now (originally inner_args)
        body,
        TensorType::new(
            shape_prim_expr_tvm_array.clone(),
            DataType::float32(),
            Span::null(),
        )
        .upcast(),
        tvm::runtime::array::Array::from_vec(vec![]).unwrap(),
    )
    .upcast();

    // Call inner_func with no arguments; this is the expression that appears malformed.
    let body: Expr = Call::new(
        inner_func.clone(),
        Array::from_vec(vec![]).unwrap(),
        Attrs::null(),
        tvm::runtime::array::Array::from_vec(vec![]).unwrap(),
        Span::null(),
    )
    .upcast();

    let irmodule = IRModule::from_expr(body).unwrap();
    println!("{}", tvm::ir::expr::as_text(irmodule));
}
```
(Note that `ndarray12` is just my renaming of `ndarray`, as I’m also using another version of `ndarray` in my crate.)
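As a side note, the `i32::try_from(*i).unwrap()` step in the test panics if a dimension does not fit in an `i32`. A minimal plain-Rust sketch of the same conversion that returns an error instead (no TVM types involved; the helper name `shape_to_i32` is hypothetical):

```rust
use std::convert::TryFrom;
use std::num::TryFromIntError;

/// Hypothetical helper: convert a shape with `usize` dimensions into the
/// `i32` values that would be wrapped in `IntImm`s. Collecting an iterator
/// of `Result`s into a `Result<Vec<_>, _>` stops at the first dimension
/// that overflows `i32` and returns that error.
fn shape_to_i32(shape: &[usize]) -> Result<Vec<i32>, TryFromIntError> {
    shape.iter().map(|&d| i32::try_from(d)).collect()
}

fn main() {
    let shape: Vec<usize> = vec![1, 2, 3];
    let dims = shape_to_i32(&shape).expect("a dimension overflowed i32");
    println!("{:?}", dims);
}
```

The resulting `Vec<i32>` could then be mapped to `IntImm`s exactly as in the test.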
The resulting `body` expr seems to be malformed: when I try to pretty-print it from within the TVM C++ source, I get an error. I can print `inner_func` just fine, so it seems that the `Call` to `inner_func` is what's malformed. Can someone check whether I'm forming the `Call` correctly?