Different output for large YOLO ONNX model in Python (correct) and C++ (incorrect)

I have a relatively large ONNX model (~80MB) created using YOLO that I am attempting to use for object recognition in 512x512 images. Using python to compile and run the model works perfectly, using both Relay and tvmc. However, when attempting to run the same compiled model in C++, the output I get behaves as though no input were given, or as though a blank image were fed in. Here is the abbreviated code that is working in python to compile and run the model.

import onnx
import tvm
from tvm import relay
from tvm.contrib import graph_executor

model = onnx.load(model_file_path)
shape_dict = {"data": (1, 3, 512, 512)}
mod, params = relay.frontend.from_onnx(model, shape_dict)
target = tvm.target.Target("llvm")
compiled = relay.build(mod, target, params=params)
compiled.export_library(compiled_model)

module = tvm.runtime.load_module(compiled_model, 'so')
dev = tvm.device(str(target), 0)
mod = graph_executor.GraphModule(module["default"](dev))

mod.set_input('data', img)
mod.run()
out = mod.get_output(0).numpy()

And here is the abbreviated code that successfully runs the model without error, but has an incorrect output:

    tvm::runtime::Module mod_syslib = tvm::runtime::Module::LoadFromFile("<path goes here>/compiled_model_2.dylib");
    DLDevice dev{kDLCPU, 0};
    tvm::runtime::Module gmod = mod_syslib.GetFunction("default")(dev);

    tvm::runtime::PackedFunc set_input = gmod.GetFunction("set_input");
    tvm::runtime::PackedFunc get_output = gmod.GetFunction("get_output");
    tvm::runtime::PackedFunc run = gmod.GetFunction("run");

    tvm::runtime::NDArray x = tvm::runtime::NDArray::Empty({1, 3, 512, 512}, DLDataType{kDLFloat, 32, 1}, dev);
    tvm::runtime::NDArray y = tvm::runtime::NDArray::Empty({1, 26, 15, 4096}, DLDataType{kDLFloat, 32, 1}, dev);

    cv::Mat image_mat = cv::imread("<Path goes here>/pic.jpg");
    cv::Mat image_form = cv::dnn::blobFromImage(image_mat, 1.0/255.0, cv::Size(512, 512), cv::Scalar(0,0,0), false);
    x.CopyFromBytes(image_form.data, image_form.total()*image_form.channels()*image_form.elemSize());

    set_input("data", x);
    run();
    get_output(0, y);

    const int *inter_dim = new int[4] {26, 15, 64, 64};
    cv::Mat out_data(y->ndim, inter_dim, CV_32FC1);
    y.CopyToBytes(out_data.data, inter_dim[0] * inter_dim[1] * inter_dim[2] * inter_dim[3] * sizeof(float));
  1. I have also tried running the model in C++ by compiling the model using tvmc and creating the graph_executor using the following method. I was able to get it to run without error, but it had the same incorrect output as the previous method:

     tvm::runtime::Module gmod = (*tvm::runtime::Registry::Get("tvm.graph_executor.create"))(json_data, mod_syslib, device_type, device_id);

  2. I've verified that the parameters are present in the C++ model, and they appear to match what is present in the compilation of the model in the Python script. The only other input is the 'data' input.
  3. I have also verified that the input is indeed being set, as scaling the image colors to 0-255 instead of the intended 0-1 range heavily affects the appearance of the output, though it is just as incorrect.
  4. I have attempted to compile the model by changing the target from "llvm" to "llvm -link-params" as suggested here, but while this compiles successfully, it causes both the Python and C++ programs to crash once they attempt to create the graph executor using the newly compiled model. Here is the stack trace when this happens:

This is being compiled on an x86 MacBook Pro with macOS Big Sur Version 11.6.

The big mystery in my mind is why the Python interface is able to run perfectly, whereas I cannot seem to get the C++ implementation to give the correct output despite using the same compiled model and inputs, no matter what method I use.

Please let me know if any more information should be added. I've spent about 3 weeks or so troubleshooting this, so any help or suggestions are massively appreciated. Thank you very much for your time and attention.

I wonder if cv::Mat is non-contiguous or if its data is in a different ordering from what you expect. Can you 1. try running the model with the input being all zeros, and 2. try manually copying each element using cv::Mat::at?
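If it helps for comparison, here is a minimal Python sketch (based on the snippet in the original post; the library path is a placeholder) for generating an all-zeros reference output that the C++ result can be diffed against:

    import numpy as np
    import tvm
    from tvm.contrib import graph_executor

    dev = tvm.cpu(0)
    module = tvm.runtime.load_module("compiled_model_2.so")  # placeholder path
    mod = graph_executor.GraphModule(module["default"](dev))

    # all-zeros input of the model's expected shape
    mod.set_input("data", np.zeros((1, 3, 512, 512), dtype="float32"))
    mod.run()

    # save the reference so the C++ output can be compared against it
    np.save("zeros_reference.npy", mod.get_output(0).numpy())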

Perhaps the image channels are in BGR order? blobFromImage has a swapRB argument you could try. I know OpenCV stores images as BGR a lot of the time.

I don't know how you did the python stuff, but if you used PIL, images are usually in RGB form, which might explain the difference.

Good thinking. Running this line of code:

std::cout << "Continuous? " << std::to_string(image_form.isContinuous()) << std::endl;

Prints 1 (true), which I believe indicates that the blob is contiguous.

This is the function used to preprocess images in the working python script, which then get fed directly into the model.

def preprocess(img_bgr, input_size):
    # first input is 'BGR' image, second is registered Ir image, both in original resolution
    # returns tensor 'frame' ready to run inference model on
    input_image_height = input_size[0]
    input_image_width = input_size[1]
    img_bgr = cv2.resize(img_bgr, dsize=(input_image_height, input_image_width), interpolation=cv2.INTER_LINEAR)  #resize B,G,R image to network input size with linear interpolation
    frame = img_bgr.transpose(2, 0, 1)  #reorder to {C,H,W} order
    frame = np.reshape(frame, (1, 3, input_image_height, input_image_width))  #reshape for 4d array for network input
    frame = (frame.astype(np.float32))/255.0   #rescale values from 0 to 1
    return frame

Images are loaded from storage using cv.imread(). The blobFromImage function appears to match the result of the Python preprocessing function, but I'd be thrilled if I missed something and the only issue was with the input.
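As a numeric sanity check of that claim, the two preprocessing paths can be compared directly in Python. A rough sketch, assuming the preprocess function above is in scope and using a placeholder image path:

    import cv2
    import numpy as np

    img_bgr = cv2.imread("pic.jpg")  # placeholder path

    # custom preprocessing from the working Python script above
    frame = preprocess(img_bgr, (512, 512))

    # Python equivalent of the blobFromImage call used in the C++ code
    blob = cv2.dnn.blobFromImage(img_bgr, 1.0 / 255.0, (512, 512), (0, 0, 0), swapRB=False)

    print(frame.shape, blob.shape)     # both should be (1, 3, 512, 512)
    print(np.abs(frame - blob).max())  # should be ~0 if the two paths really agree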

Changing the blobFromImage function to

cv::Mat image_form = cv::dnn::blobFromImage(image_mat, 0, cv::Size(512, 512), cv::Scalar(0,0,0), false);

in order to set the input to all zeros does appear to have an effect on the model output, lowering the maximum confidence of the outputs from below 20% to below 10%. Manually copying the image data using this (I had trouble using cv::Mat::at on the blob for some reason)

    int *in_loc = new int[4] {1,0,0,0};
    float *in_data = reinterpret_cast<float *>(x->data);
    float *in_img = reinterpret_cast<float *>(image_form.data);
    for (in_loc[1] = 0; in_loc[1] < image_form.size[1]; in_loc[1]++) {
        for (in_loc[2] = 0; in_loc[2] < image_form.size[2]; in_loc[2]++) {
            for (in_loc[3] = 0; in_loc[3] < image_form.size[3]; in_loc[3]++) {
                in_data[in_loc[1]*512*512 + in_loc[2]*512 + in_loc[3]] = in_img[in_loc[1]*512*512 + in_loc[2]*512 + in_loc[3]];
            }
        }
    }

Also does not appear to have any effect on the output.

Also an excellent idea. From the Python code and the creator of the model, I believe the model was designed to accept BGR input rather than RGB. I did try swapping those channels anyway in C++ using the swapRB parameter, but it does not appear to affect the output. The importing and processing of test images is all handled in OpenCV in both C++ and Python, though I did try using PIL in Python at some point and I believe I was able to get a good output after changing the RGB channels to BGR.
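For reference, this is the kind of channel swap being described, purely as an illustration (the image path is a placeholder, not from the original scripts):

    import numpy as np
    import cv2
    from PIL import Image

    img_rgb = np.asarray(Image.open("pic.jpg"))         # PIL loads images as RGB
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)  # reorder channels to the BGR the model expects
    # equivalently, a pure-numpy flip of the channel axis:
    # img_bgr = img_rgb[:, :, ::-1]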

When you use all zeros as the input, does the output match the python code?

I suggest using cv::Mat::at (or cv::Mat::operator()) because it will give you the correct data point regardless of the order that the data is stored in memory.

Yes, the outputs from running the model in Python and in C++ look the same when the data input is set to all zeros. I've modified the code to use the at function:

    int *in_loc = new int[4] {0};
    const int* dims = new int[4] {image_form.size[0], image_form.size[1], image_form.size[2], image_form.size[3]};
    float *in_data_raw = reinterpret_cast<float *>(x->data);
    cv::Mat in_data{4, dims, CV_32FC1, x->data, 0};
    for (in_loc[1] = 0; in_loc[1] < image_form.size[1]; in_loc[1]++) {
        for (in_loc[2] = 0; in_loc[2] < image_form.size[2]; in_loc[2]++) {
            for (in_loc[3] = 0; in_loc[3] < image_form.size[3]; in_loc[3]++) {
                in_data.at<float>(in_loc) = image_form.at<float>(in_loc);
                // Also tried this, with same result
                //in_data_raw[in_loc[1]*512*512 + in_loc[2]*512 + in_loc[3]] = image_form.at<float>(in_loc);
            }
        }
    }

But it appears to have given the same result, with confidence levels all below 20%.

hi @ryan-csm,

Could you try running with -link-params as was suggested on the other post? To do so:

  1. add -link-params to the Target string.
  2. if you're just doing the following:
    compiled = tvm.relay.build(...)
    compiled.export_library("foo.so")
    
    the GraphExecutorFactory will shoot you in the foot at runtime as you observed. Instead, you need to export like so: compiled.get_lib().export_library("foo.so"). you then also need to export the Graph JSON separately: open("graph.json", "w").write(compiled.get_json()).
  3. then, at runtime, load the library and JSON and instantiate. i'll be interested to see what you get.

filed https://github.com/apache/tvm/issues/9570 to track the segfault you experienced.

This code doesn't look correct. y has shape {1, 26, 15, 4096}, but you're putting it into a {26, 15, 64, 64} matrix.

It does look odd, but it's based on the output being a contiguous block of memory, with the array dimensions defined just by step sizes. Since the dimensions stay in the same order, I believe I can drop the leading batch dimension of 1, which just wraps the data, and the final 4096 entries are already laid out as a 64x64 block within that 1D stretch of memory, so those adjacent dimensions can be expanded this way without moving any data. These are quirks of the model I'm trying to use, which I've worked out with its creator.
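For what it's worth, that reasoning can be checked with a small numpy sketch (random data standing in for the real output):

    import numpy as np

    # the graph output is one contiguous (1, 26, 15, 4096) float32 block
    out = np.random.rand(1, 26, 15, 4096).astype(np.float32)

    # dropping the leading batch dimension and splitting 4096 into 64x64 only
    # reinterprets the same contiguous memory; no elements move
    reshaped = out.reshape(26, 15, 64, 64)
    assert np.shares_memory(out, reshaped)
    assert reshaped[3, 7, 10, 20] == out[0, 3, 7, 10 * 64 + 20]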

It doesn't seem like the compiled model has get_json as a member. Is there another way to get this data?

class ExecutorFactoryModule:
    """Common interface for executor factory modules
    This class describes the common API of different
    factory modules
    """

    @abstractmethod
    def get_executor_config(self):
        """Return the internal configuration the executor uses to execute the network"""
        raise NotImplementedError

    @abstractmethod
    def get_params(self):
        """Return the compiled parameters."""
        raise NotImplementedError

    @abstractmethod
    def get_lib(self):
        """Return the generated library"""
        raise NotImplementedError

    def __getitem__(self, item):
        return self.module.__getitem__(item)

    def __iter__(self):
        warnings.warn(
            "legacy graph executor behavior of producing json / lib / params will be "
            "removed in the next release."
            " Please see documents of tvm.contrib.graph_executor.GraphModule for the "
            " new recommended usage.",
            DeprecationWarning,
            2,
        )
        return self

    def __next__(self):
        if self.iter_cnt > 2:
            raise StopIteration

        objs = [self.get_executor_config(), self.lib, self.params]
        obj = objs[self.iter_cnt]
        self.iter_cnt += 1
        return obj

I attempted to use it, but it produced an error.

Ah, never mind, using get_executor_config appears to work.

Okay, I was able to compile and save the model, params, and json using this method, and it appears that I am now able to load the module into C++ and create the graph executor using the method

tvm::runtime::Module mod_syslib = tvm::runtime::Module::LoadFromFile("<path goes here>/mod_2.so");
tvm::runtime::Module gmod = (*tvm::runtime::Registry::Get("tvm.graph_executor.create"))(json_data, mod_syslib, device_type, device_id);

I can also set the input, run the model, and extract the output successfully, but attempting to load the parameters into the model using this method crashes the program.

std::ifstream params_in("<path goes here>/mod.params", std::ios::binary);
std::string params_data((std::istreambuf_iterator<char>(params_in)), std::istreambuf_iterator<char>());
params_in.close();
TVMByteArray params_arr;
params_arr.data = params_data.c_str();
params_arr.size = params_data.length();
tvm::runtime::PackedFunc load_params = gmod.GetFunction("load_params");
load_params(params_arr);

With the following stack trace:

(The error occurred on the load_params function call)

yeah so you don't need to load params in this case; they are baked into the program (in .text) and can't be changed.

Got it. Running the model that was compiled using this method works, but still has the same output, with all confidence levels below 20%.

ok. it seems like perhaps you're doing something in Python that's different than what you're doing in the other use cases. a couple other suggestions:

  1. try running inference in a separate Python script from the one you use to build. e.g. put this stuff in a separate Python script and see if you can reproduce the correct or incorrect output.
    module = tvm.runtime.load_module(compiled_model, 'so')
    dev = tvm.device(str(target), 0)
    mod = graph_executor.GraphModule(module["default"](dev))
    
    mod.set_input('data', img)
    mod.run()
    out = mod.get_output(0).numpy()
    
  2. try using the DebugExecutor (e.g. mod = debug_executor.GraphModuleDebug(compiled_model["debug_create"]); a rough sketch follows this list). this should give you layer-by-layer analysis. perhaps there is stack corruption happening outside of Python, given you have over 300 parameters?
  3. try running your model in C++ without QT linked in.
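To make suggestion 2 concrete, a rough Python sketch of per-node dumping with the debug executor might look like this (assuming TVM was built with the debug graph executor enabled; paths and the dump directory are placeholders):

    import numpy as np
    import tvm
    from tvm.contrib.debugger import debug_executor

    dev = tvm.cpu(0)
    lib = tvm.runtime.load_module("mod_2.so")  # placeholder: library exported via get_lib().export_library
    graph_json = open("mod.json").read()       # placeholder: the exported graph JSON

    # create() builds a debug GraphModule; after run() it dumps per-node outputs
    # and timing information under dump_root for layer-by-layer comparison
    m = debug_executor.create(graph_json, lib, dev, dump_root="/tmp/tvmdbg")
    m.set_input("data", np.zeros((1, 3, 512, 512), dtype="float32"))
    m.run()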

Okay, interesting results.

  1. I was able to run a similar python script successfully with good output on a separate Mac computer.
  2. I got some interesting results from this. I compared the dimensions and outputs of each node in the graph when running in C++ and Python, and found them to be identical until almost the end, at node 370. The column on the left shows the node numbers and dimensions from C++, the one to its right is from the Python script, and the rightmost column is the op for each node.
  3. I was able to compile and run the C++ project using gcc and clang, but there was no change in the output from when I ran it in QT.

Here is the netron view of the last nodes in the original onnx model that TVM is compiling, if that helps.

Do these provide any insight? The output dimensions and content from the C++ and python models are consistent, but I have no idea what may be causing the difference at the end.

I should note, those nodes are the only Softmax/Transpose nodes in the model.

@ryan-csm is it possible for you to isolate the problem in a smaller subgraph you could share? the weights likely don't matter too much here (e.g. random weights would probably demonstrate the issue after a few tries). alternatively, is it possible you could "bisect" by slowly peeling layers from the end of the model until the two environments agree all the time?

Hey Ryan, this is very interesting. If I'm understanding correctly, the shapes don't even match for node 370 between Python and C++?

So I'm guessing the input to node 370 is a tensor of shape [1, 2, 15, 4096] from the split operator. Very interestingly, both the permutation of the transpose and the shape appear to differ between the two. For each node, do you also have the input nodes?

Alternatively you can just share the code and model you are running if possible and I can take a closer look.

Okay, sorry for the delay. It looks like part of the issue may have been exacerbated by the debug graph executor. Using the new compilation method with -link-params, and after some adjustments in the C++ implementation to keep the creation of the graph executor from crashing, running the model now appears to work the same in both C++ and Python. What works is compiling the model using this method in Python:

model = onnx.load(model_file_path)
shape_dict = {"data": (1, 3, 512, 512)}
mod, params = relay.frontend.from_onnx(model, shape_dict)
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
      target = tvm.target.Target("llvm -link-params")
      compiled = relay.build(mod, target=target, params=params)
file = open("./mod.json", "w")
file.write(compiled.get_executor_config())
file.close()
compiled.get_lib().export_library(compiled_model)

And running it in C++ using:

tvm::runtime::Module mod_syslib = tvm::runtime::Module::LoadFromFile(model_file_path);
std::ifstream json_in(json_file_path, std::ios::in);
std::string json_data((std::istreambuf_iterator<char>(json_in)),std::istreambuf_iterator<char>());
json_in.close();

DLDevice dev{kDLCPU, 0};
int64_t device_type = kDLCPU;
int64_t device_id = 0;

tvm::runtime::Module gmod = (*tvm::runtime::Registry::Get("tvm.graph_executor.create"))(json_data, mod_syslib, device_type, device_id);

tvm::runtime::PackedFunc set_input = gmod.GetFunction("set_input");
tvm::runtime::PackedFunc get_output = gmod.GetFunction("get_output");
tvm::runtime::PackedFunc run = gmod.GetFunction("run");

tvm::runtime::NDArray x = tvm::runtime::NDArray::Empty({1, 3, 512, 512}, DLDataType{kDLFloat, 32, 1}, dev);
tvm::runtime::NDArray y = tvm::runtime::NDArray::Empty({1, 26, 15, 4096}, DLDataType{kDLFloat, 32, 1}, dev);

cv::Mat image_mat = cv::imread(image_file_path);
cv::Mat image_form = cv::dnn::blobFromImage(image_mat, 1.0/255.0, cv::Size(512, 512), cv::Scalar(0,0,0), false);

x.CopyFromBytes(image_form.data, image_form.total()*sizeof(float));
TVMSynchronize(dev.device_type, dev.device_id, nullptr);
set_input("data", x);
TVMSynchronize(dev.device_type, dev.device_id, nullptr);
run();
TVMSynchronize(dev.device_type, dev.device_id, nullptr);
get_output(0, y);

Using the debug graph executor seems to have produced a different output between the two implementations, but using the regular graph executor now produces the same result. Unfortunately, I can't post the model at the moment, but I may try to get a stripped down version to see if these issues can be reproduced. Is there anywhere I can learn more about the differences between compilation and model loading methods, and what the contents of those different models are? I cannot thank you all enough for the help you have provided. There are a few tasks I need to dedicate time to at the moment, but I'll try to update this post once I have time to learn more about what caused the issue.
