What I’ve done:
- PyTorch Model -> ONNX:
Here is my model, which takes input tensors of shape (1, 1, 210, 405):
import torch.nn as nn

class CNN(nn.Module):
    # for 210 x 405 inputs
    def __init__(self, num_classes=2):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 2, kernel_size=5, stride=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(2, 4, kernel_size=5, stride=2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(4, 4, kernel_size=5, stride=2),
            nn.ReLU(inplace=True)
        )
        self.fc = nn.Sequential(
            nn.Linear(1012, num_classes),
        )

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)  # flatten the 4-D feature map to (N, 1012) before the FC layer
        out = self.fc(out)
        return out
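The export itself is a plain torch.onnx.export call, roughly like the sketch below (the dummy-input shape matches the model above; in my actual script the trained weights are loaded first, and the output file name is just what I pass to onnx.load later):

    import torch

    model = CNN(num_classes=2)
    model.eval()  # trained weights are loaded before exporting in my actual script

    # dummy input with the same shape as the real inputs
    dummy_input = torch.randn(1, 1, 210, 405)

    # export to ONNX; "symbol_apply.onnx" is the file I feed to OpenCV and TVM below
    torch.onnx.export(model, dummy_input, "symbol_apply.onnx")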
The resulting .onnx file seems fine: I tried it with OpenCV’s cv::readNetFromONNX and it works very well, with surprisingly good performance.
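In case it matters, the quick OpenCV check was essentially the following (I used the C++ API, cv::dnn::readNetFromONNX; this is just the Python equivalent as a sketch, with placeholder paths):

    import cv2

    net = cv2.dnn.readNetFromONNX("symbol_apply.onnx")

    # load a grayscale test image and scale it to [0, 1]
    img = cv2.imread("0-0.bmp", cv2.IMREAD_GRAYSCALE)
    blob = cv2.dnn.blobFromImage(img, scalefactor=1.0 / 255.0)  # -> (1, 1, H, W)

    net.setInput(blob)
    out = net.forward()
    print(out.argmax())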
However, when I was trying to compile my .onnx file via TVM:
- ONNX->TVM
import numpy as np
from PIL import Image
import tvm
import tvm.relay
import onnx

dtype = 'float32'
img = Image.open("../symbol_dataset_for_training/0-0.bmp").convert('L')
# img = np.array(img).transpose((2, 0, 1)).astype('float32')  # if testing with OpenCV-style (H, W, C) input from the start
img = np.array(img).astype(dtype)
img = img / 255.0
x = img[np.newaxis, np.newaxis, :]  # (1, 1, H, W)
print(x.shape)

onnx_model = onnx.load("../onnx_test/symbol_apply.onnx")
opt_level = 3  # optimization level
target = tvm.target.create("llvm -mcpu=haswell")
shape_dict = {'0': (1, 1, 405, 210)}
sym, params = tvm.relay.frontend.from_onnx(onnx_model, shape_dict)

with tvm.relay.build_config(opt_level=opt_level):
    intrp = tvm.relay.build_module.create_executor('graph', sym, tvm.cpu(0), target)
    func = intrp.evaluate(sym)
    tvm_output = func(tvm.nd.array(x.astype(dtype)), **params).asnumpy()

max_index = tvm_output.argmax()
print(max_index)
I got these sad logs:
Error(s) have occurred. We have annotated the program with them:
In `main`:
v0.0.1
fn (%v0: Tensor[(1, 1, 405, 210), float32]) {
%0 = nn.conv2d(%v0, meta[relay.Constant][0] /* */, kernel_size=[5, 5]) /* */
%1 = nn.bias_add(%0, meta[relay.Constant][1] /* */) /* */
%2 = nn.relu(%1) /* */
%3 = nn.max_pool2d(%2, pool_size=[2, 2], strides=[2, 2]) /* */
%4 = nn.conv2d(%3, meta[relay.Constant][2] /* */, strides=[2, 2], kernel_size=[5, 5]) /* */
%5 = nn.bias_add(%4, meta[relay.Constant][3] /* */) /* */
%6 = nn.max_pool2d(%5, pool_size=[2, 2], strides=[2, 2]) /* */
%7 = nn.relu(%6) /* */
%8 = nn.conv2d(%7, meta[relay.Constant][4] /* */, strides=[2, 2], kernel_size=[5, 5]) /* */
%9 = nn.bias_add(%8, meta[relay.Constant][5] /* */) /* */
%10 = nn.relu(%9) /* */
%11 = shape_of(%10, dtype="int32") /* */
%12 = take(%11, int64(0), axis=0) /* */
%13 = expand_dims(%12, axis=0) /* */
%14 = expand_dims(int64(-1), axis=0) /* */
%15 = (%13, %14)
concatenate(%15) /* an internal invariant was violated while typechecking your program [00:28:37] /Users/ganler-mac/Desktop/local store/test/tvm/src/relay/op/tensor/transform.cc:204: Check failed: e_dtype == dtype (int64 vs. int32) : relay.concatenate requires all tensors have the same dtype
; */
}
// meta data omitted. you can use show_meta_data=True to include meta data
I don’t know where the “int64” comes from, since relay.concatenate requires all tensors to have the same dtype. From the annotated program above, it looks like the int32 output of shape_of in %11 is being concatenated with the int64(-1) constant in %14 (both coming from the view/flatten before the fully connected layer), but I don’t understand why they end up with different dtypes.
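For reference, this is how I would dump the exported graph to see where those shape-related nodes come from (just the standard onnx helpers, nothing model-specific):

    import onnx

    model = onnx.load("../onnx_test/symbol_apply.onnx")
    onnx.checker.check_model(model)                  # sanity-check the export
    print(onnx.helper.printable_graph(model.graph))  # look for the Shape/Gather/Concat/Reshape nodes produced by the view/flatten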
Can anyone help me with this issue?
Thanks a lot!