Hey all,
I am working on a model that is written in PyTorch and exported to ONNX. During relay.build with opt level = 1, I ran into the following type mismatch (the error does not occur with opt level = 0):
TypeError: Check failed: a.dtype() == b.dtype(): mismatched types
Error during compile function
-----------------------------
v0.0.4
fn (%p0: Tensor[(7, 1, 32, 1536), float32], Primitive=1) -> Tensor[(7, 1, 1536), float32] {
%0 = reshape(%p0, newshape=[7, 1, -1, 1536]) /* ty=Tensor[(7, 1, 32, 1536), float32] */;
%1 = take(%0, 0 /* ty=int64 */, axis=2) /* ty=Tensor[(7, 1, 1536), float32] */;
reshape(%1, newshape=[-1, 1, 1536]) /* ty=Tensor[(7, 1, 1536), float32] */
}
This check is in expr.h, in the function BinaryOpNode::make. The issue seems to come from the fact that the second input of take is an int64. I was able to work around it by casting this argument to int32 in the ONNX frontend, but that solution really isn’t ideal.
I traced the take call back to this line of PyTorch:
new_data = all_data[:,:,0]
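For reference, here is a small NumPy sketch (NumPy is only used for illustration; it is not part of the model) that mirrors the fused function from the error dump and checks it against the PyTorch-style slice. The input shape (7, 1, 32, 1536) is taken from the dump; the random data and variable names are made up.

```python
import numpy as np

# Stand-in for %p0 in the fused function (shape taken from the error dump).
rng = np.random.default_rng(0)
p0 = rng.standard_normal((7, 1, 32, 1536)).astype(np.float32)

# Mirror of the fused Relay function:
#   %0 = reshape(%p0, newshape=[7, 1, -1, 1536])   (-1 is inferred as 32)
#   %1 = take(%0, 0 /* int64 */, axis=2)
#   reshape(%1, newshape=[-1, 1, 1536])
x0 = p0.reshape(7, 1, -1, 1536)
index = np.int64(0)                 # scalar int64 index, as in the dump
x1 = np.take(x0, index, axis=2)     # scalar index drops axis 2
fused_out = x1.reshape(-1, 1, 1536)

# The PyTorch line being exported: new_data = all_data[:,:,0]
sliced = p0[:, :, 0]

print(fused_out.shape, sliced.shape)
```

Both paths produce a (7, 1, 1536) float32 tensor with identical values, so the take with an int64 index is just the exported form of the slice; the failure is purely about the index dtype during compilation.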
This is not the first time I’ve seen small type-mismatch errors like this when importing from ONNX, and I don’t think that a change on the PyTorch side is the right way to fix it (I’m not even sure one is possible, given how simple this line is).
Is there a way we can fix this in TVM? For example, can we cast the int64 to int32? Also, why does it only show up when opt level = 1? Presumably this has to do with fusion, but I haven’t been able to figure out the root cause.
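On the question of casting int64 to int32: the cast is only lossless when every index value fits in int32. A minimal sketch of that guard, written with NumPy for illustration (narrow_indices is a hypothetical helper, not a TVM or ONNX API; in Relay the equivalent step would presumably be a relay.cast on the index expression), narrows to int32 when it is safe and keeps int64 otherwise.

```python
import numpy as np

INT32 = np.iinfo(np.int32)

def narrow_indices(indices):
    """Hypothetical helper: return `indices` as int32 when that is lossless.

    If any value falls outside the int32 range, the int64 array is
    returned unchanged rather than silently wrapping.
    """
    idx = np.asarray(indices, dtype=np.int64)
    if idx.size and (idx.min() < INT32.min or idx.max() > INT32.max):
        return idx  # does not fit in int32; keep int64
    return idx.astype(np.int32)

# Usage: the scalar index from all_data[:,:,0] narrows safely,
# while an index beyond the int32 range stays int64.
print(narrow_indices(np.int64(0)).dtype)   # int32
print(narrow_indices([0, 2**40]).dtype)    # int64
```

Checking the values rather than casting unconditionally matters because an int64 to int32 cast wraps around silently for large tensors; for constant indices like this one the check can happen at import time.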
Thanks!