Yolo: Python (works) vs. C++ (fails)

Compiling the basic yolo model from mxnet/gluoncv used to work in TVM 0.6.0, but since I updated to 0.7.dev1 (built with cuda, cudnn, cublas, thrust), my code has broken. Compiling and testing in Python still works fine, but my C++ deployment code fails: I keep hitting a radix_sort error that occurs only from C++, never when testing the same model in Python. Have I overlooked C++ API changes that no longer align with the Python API?

Searching around suggests it may have something to do with the thrust::sort_by_key call at https://github.com/apache/incubator-tvm/blob/master/src/runtime/contrib/thrust/thrust.cu#L61, but I can’t tell for sure.
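For context, here is a minimal standalone sketch (my own, not TVM code) of the pattern that line uses: sort a key array on the device and permute an index array along with it. If something like this compiles and runs cleanly under nvcc, thrust itself is probably fine on the machine:

// sort_by_key_repro.cu -- standalone sketch of the call at thrust.cu#L61 (build with nvcc)
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <cstdio>

int main() {
  const int n = 5;
  float host_keys[n] = {0.3f, 0.9f, 0.1f, 0.7f, 0.5f};
  thrust::device_vector<float> keys(host_keys, host_keys + n);
  thrust::device_vector<int> indices(n);
  thrust::sequence(indices.begin(), indices.end());  // 0,1,2,3,4
  // descending sort, matching the is_ascend == false branch in thrust.cu
  thrust::sort_by_key(keys.begin(), keys.end(), indices.begin(),
                      thrust::greater<float>());
  for (int i = 0; i < n; ++i)
    std::printf("key=%.1f idx=%d\n", static_cast<float>(keys[i]),
                static_cast<int>(indices[i]));
  return 0;
}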

Tried on a 1660, a 1080Ti, a TitanX, and a TX2: same result on all of them.

Error Output C++

terminate called after throwing an instance of 'dmlc::Error'
  what():  [14:36:16] /opt/src/tvm/src/runtime/library_module.cc:78: Check failed: ret == 0 (-1 vs. 0) : radix_sort: failed on 2nd step: cudaErrorInvalidValue: invalid argument

Stack trace:
  [bt] (0) ./main(dmlc::LogMessageFatal::~LogMessageFatal()+0x61) [0x41a331]
  [bt] (1) /usr/local/lib/libtvm_runtime.so(+0x76aa3) [0x7fdb5537caa3]
  [bt] (2) /usr/local/lib/libtvm_runtime.so(+0xe15f7) [0x7fdb553e75f7]
  [bt] (3) /usr/local/lib/libtvm_runtime.so(tvm::runtime::GraphRuntime::Run()+0x47) [0x7fdb553e7687]
  [bt] (4) ./main(MinimalYolo::forward_full(cv::Mat)+0x68c) [0x417a4c]
  [bt] (5) ./main(main+0x492) [0x412dc2]
  [bt] (6) /lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0xf0) [0x7fdb3b2ab830]
  [bt] (7) ./main(_start+0x29) [0x413169]

Python compile + test on street_small image

import cv2
import mxnet as mx
from mxnet.gluon.data.vision import transforms
from gluoncv import model_zoo, data, utils
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.contrib.download import download_testdata

im_fname = download_testdata('https://github.com/dmlc/web-data/blob/master/' +
                         'gluoncv/detection/street_small.jpg?raw=true',
                         'street_small.jpg', module='data')

TRANSFORM_FN = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225])
])

img = cv2.imread(im_fname)[..., ::-1]
img_t = cv2.resize(img, (256, 256))
img_t = TRANSFORM_FN(mx.nd.array(img_t))
img_n = img_t.expand_dims(0).asnumpy()

ctx = tvm.gpu(0)
target = 'cuda -libs=cudnn,cublas -model=titanx'
block = model_zoo.get_model('yolo3_mobilenet1.0_coco', pretrained=True)
mod, params = relay.frontend.from_mxnet(block, shape={'data': (1,3,256,256)}, dtype='float32')
net = mod["main"]
net = relay.Function(net.params, net.body, None, net.type_params, net.attrs)
mod = tvm.IRModule.from_expr(net)

# target = tvm.target.cuda('titanx')
with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = relay.build_module.build(
        mod, target=target, params=params) 

module = graph_runtime.create(graph, lib, ctx)
module.set_input(**params)
module.set_input('data', tvm.nd.array(img_n.astype('float32')))
module.run()
output0 = module.get_output(0)
output1 = module.get_output(1)
output2 = module.get_output(2)
print(output0)

output_name = '256.yolo.cuda.titanx'
lib.export_library(
    "{}.so".format(output_name))
print('lib export success')
with open("{}.json".format(output_name), "w") as fo:
    fo.write(graph)
print("graph export success")
with open("{}.params".format(output_name), "wb") as fo:
    fo.write(relay.save_param_dict(params))
print("params export success")

C++ deploy code using compiled lib

#include <cstdio>
#include <opencv2/opencv.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/highgui.hpp>
#include <dlpack/dlpack.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <fstream>
#include <iostream>
#include <iomanip>
#include <string>
#include <sstream>
#include <map>
#include <cmath>
#include <random>

class MinimalYolo{
    private:
        std::unique_ptr<tvm::runtime::Module> detector_handle;

    public:
        std::string deploy_lib_path;
        std::string deploy_graph_path;
        std::string deploy_param_path;
        bool gpu = true;
        int device_id = 0;  // must be initialized: it is passed to TVM alloc/create calls below
        int dtype_code = kDLFloat;
        int dtype_bits = 32;
        int dtype_lanes = 1;
        int device_type = kDLGPU;
        int detector_width = 256;
        int detector_height = 256;
        int detector_total_input = 3 * detector_width * detector_height;
        int in_ndim = 4;
        int detector_out_ndim = 3;
        int64_t tvm_id_and_score_size[3] = {1, 100, 1};
        int64_t tvm_box_size[3] = {1, 100, 4};

        /**
         * constructor that loads the compiled yolo detector
         * (the .so lib, the .json graph, and the .params weights)
         */
        MinimalYolo(std::string detector_path) {
            std::cout << "start model_config reading" << std::endl;

            std::string detector_deploy_lib_path =  detector_path + ".so";
            std::string detector_deploy_graph_path =  detector_path + ".json";
            std::string detector_deploy_param_path =  detector_path + ".params";
            tvm::runtime::Module detector_mod_syslib = tvm::runtime::Module::LoadFromFile(detector_deploy_lib_path);
            std::ifstream detector_json_in(detector_deploy_graph_path, std::ios::in);
            std::string detector_json_data((std::istreambuf_iterator<char>(detector_json_in)), std::istreambuf_iterator<char>());
            detector_json_in.close();
            tvm::runtime::Module detector_mod = (*tvm::runtime::Registry::Get("tvm.graph_runtime.create"))(detector_json_data, detector_mod_syslib,
                                                                                                  device_type, device_id);
            this->detector_handle.reset(new tvm::runtime::Module(detector_mod));
            std::ifstream detector_params_in(detector_deploy_param_path, std::ios::binary);
            std::string detector_params_data((std::istreambuf_iterator<char>(detector_params_in)), std::istreambuf_iterator<char>());
            detector_params_in.close();
            TVMByteArray detector_params_arr;
            detector_params_arr.data = detector_params_data.c_str();
            detector_params_arr.size = detector_params_data.length();
            tvm::runtime::PackedFunc detector_load_params = detector_mod.GetFunction("load_params");
            detector_load_params(detector_params_arr);
        }
      
        /**
         * \brief normalize an image before it is processed by the network
         * \param[in] frame the raw cv::Mat image
         * \return the normalized version of the image.
         */
        cv::Mat preprocess_image(cv::Mat frame, int width, int height, bool convert){
            cv::Size new_size = cv::Size(width, height);
            cv::Mat resized_image;
            if (convert){
              cv::Mat rgb;
              cv::cvtColor(frame, rgb,  cv::COLOR_BGR2RGB);
              cv::resize(rgb, resized_image, new_size);
            } else {
              cv::resize(frame, resized_image, new_size);
            }
            cv::Mat resized_image_floats(new_size, CV_32FC3);
            resized_image.convertTo(resized_image_floats, CV_32FC3, 1.0f/255.0f);
            cv::Mat normalized_image(new_size, CV_32FC3);
            cv::Mat mean(new_size, CV_32FC3, cv::Scalar(0.485, 0.456, 0.406));   // ImageNet per-channel mean
            cv::Mat theta(new_size, CV_32FC3, cv::Scalar(0.229, 0.224, 0.225));  // ImageNet per-channel std-dev
            cv::Mat temp;
            temp = resized_image_floats - mean;
            normalized_image = temp / theta;
            return normalized_image;
        }
        /**
         * \brief minimal example of inference
         * \param[in] frame the raw cv::Mat image
         */
        void forward_full(cv::Mat frame)
        {
            std::cout << "starting function" << std::endl;
            cv::Size image_size = frame.size();
            float img_height = static_cast<float>(image_size.height);
            float img_width = static_cast<float>(image_size.width);

            int64_t in_shape[4] = {1, 3, detector_height, detector_width};
            int total_input = 3 * detector_width * detector_height;
            std::cout << "width: " << detector_width << std::endl;
            std::cout << "height: " << detector_height << std::endl;
            std::cout << "total_input: " << total_input << std::endl;
            std::cout << "device_id: " << device_id << std::endl;
            std::cout << "dtype_code: " << dtype_code << std::endl;
            std::cout << "dtype_bits: " << dtype_bits << std::endl;
            std::cout << "dtype_lanes: " << dtype_lanes << std::endl;
            std::cout << "device_type: " << device_type << std::endl;

            DLTensor *output_tensor_ids;
            DLTensor *output_tensor_scores;
            DLTensor *output_tensor_bboxes;
            DLTensor *input;
            float *data_x = (float *) malloc(total_input * sizeof(float));

            std::cout << "about to allocate info" << std::endl;
            // allocate DLTensor memory on device for all the vars needed
            TVMArrayAlloc(in_shape, in_ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &input);
            TVMArrayAlloc(tvm_id_and_score_size, detector_out_ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &output_tensor_ids);
            TVMArrayAlloc(tvm_id_and_score_size, detector_out_ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &output_tensor_scores);
            TVMArrayAlloc(tvm_box_size, detector_out_ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &output_tensor_bboxes);
            std::cout << "allocate info finished" << std::endl;

            //copy processed image to DLTensor
            std::cout << "about to preprocess" << std::endl;
            cv::Mat processed_image = preprocess_image(frame, detector_width, detector_height, true);
            std::cout << "preprocess finished" << std::endl;
            cv::Mat split_mat[3];
            cv::split(processed_image, split_mat);
            // repack the interleaved HWC image as CHW planes; after the BGR->RGB
            // conversion in preprocess_image the planes are R,G,B, so copy them in
            // order to match the Python pipeline's RGB input
            memcpy(data_x, split_mat[0].ptr<float>(), processed_image.cols * processed_image.rows * sizeof(float));
            memcpy(data_x + processed_image.cols * processed_image.rows, split_mat[1].ptr<float>(),
                  processed_image.cols * processed_image.rows * sizeof(float));
            memcpy(data_x + processed_image.cols * processed_image.rows * 2, split_mat[2].ptr<float>(),
                  processed_image.cols * processed_image.rows * sizeof(float));
            TVMArrayCopyFromBytes(input, data_x, total_input * sizeof(float));
            std::cout << "TVMArrayCopyFromBytes finished" << std::endl;           

            // standard tvm module run
            tvm::runtime::Module *mod = (tvm::runtime::Module *) detector_handle.get();
            tvm::runtime::PackedFunc set_input = mod->GetFunction("set_input");
            set_input("data", input);
            tvm::runtime::PackedFunc run = mod->GetFunction("run");
            run();
            tvm::runtime::PackedFunc get_output = mod->GetFunction("get_output");
            std::cout << "run/getoutput/setinput finished" << std::endl;
  
            // https://github.com/apache/incubator-tvm/issues/979?from=timeline
            TVMSynchronize(device_type, device_id, nullptr);
            get_output(0, output_tensor_ids);
            get_output(1, output_tensor_scores);
            get_output(2, output_tensor_bboxes);
            std::cout << "TVMSynchronize finished" << std::endl;  
            TVMArrayFree(input);
            TVMArrayFree(output_tensor_ids);
            TVMArrayFree(output_tensor_scores);
            TVMArrayFree(output_tensor_bboxes);
            input = nullptr;
            output_tensor_ids = nullptr;
            output_tensor_scores = nullptr;
            output_tensor_bboxes = nullptr;
            free(data_x);
            data_x = nullptr;
        }
};

int main(int argc, char** argv)
{   
    cv::Mat raw_image;
    raw_image = cv::imread("street_small.jpg");
    MinimalYolo yolo("256.yolo.cuda.titanx");
    yolo.forward_full(raw_image);
}
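One aside about the constructor above: if libtvm_runtime was built without the graph runtime, Registry::Get("tvm.graph_runtime.create") returns nullptr and the dereference segfaults before producing any useful error. A defensive sketch of the lookup (variable name create_fn is mine; assumes <stdexcept> is included):

// check the registry lookup before dereferencing it
const tvm::runtime::PackedFunc* create_fn =
    tvm::runtime::Registry::Get("tvm.graph_runtime.create");
if (create_fn == nullptr) {
  throw std::runtime_error(
      "tvm.graph_runtime.create not found: was libtvm_runtime built with the graph runtime?");
}
tvm::runtime::Module detector_mod =
    (*create_fn)(detector_json_data, detector_mod_syslib, device_type, device_id);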

I edited thrust.cu to add debug output and compared the C++ and Python runs:

// Performs sorting along axis -1 and returns both sorted values and indices.
template<typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input,
                DLTensor* out_values,
                DLTensor* out_indices,
                bool is_ascend,
                const std::function<int(int)> &get_sort_len) {
  thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
  thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
  thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));

  int n_values = input->shape[input->ndim - 1];
  int n_iter = 1;
  for (int i = 0; i < input->ndim - 1; ++i) {
    n_iter *= input->shape[i];
  }
  std::cout << "thrust: is_ascend: " << is_ascend << std::endl;
  std::cout << "thrust: n_iter: " << n_iter << std::endl;
  std::cout << "thrust: starting to copy" << std::endl;
  thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);
  std::cout << "thrust: copy successful" << std::endl;
  std::cout << "thrust: starting loop" << std::endl;
  std::cout << "thrust: values_ptr: " << values_ptr << std::endl;
  for (int i = 0 ; i < n_iter; ++i) {
    n_values = get_sort_len(i);
    thrust::sequence(indices_ptr, indices_ptr + n_values);
    std::cout << "thrust: sequence call sucessful for iter: " << i << std::endl;
    if (is_ascend) {
      thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
      std::cout << "thrust: sort_by_key is_ascend successful for iter: "<< i << std::endl;
    } else {
      thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
                          thrust::greater<DataType>());
      std::cout << "thrust: sort_by_key successful for iter: "<< i << std::endl;
    }
    values_ptr += n_values;
    indices_ptr += n_values;
  }
}
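Because the “failed on 2nd step” message comes from deep inside thrust/cub, it also helps to bracket the suspect call with explicit CUDA error checks so you can see which call actually poisons the error state. A sketch of that instrumentation (illustrative only, not part of upstream thrust.cu; it reuses values_ptr, indices_ptr, n_values, and DataType from the function above):

// clear any sticky error left by an earlier call, then sort, then force
// asynchronous errors to surface at this exact point
cudaError_t pre = cudaGetLastError();
std::cout << "thrust: pre-sort error state: " << cudaGetErrorString(pre) << std::endl;
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
                    thrust::greater<DataType>());
cudaError_t post = cudaDeviceSynchronize();
if (post != cudaSuccess) {
  std::cout << "thrust: sort_by_key failed: " << cudaGetErrorString(post) << std::endl;
}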

Python output

thrust: float call: float32
thrust: is_ascend: 0
thrust: n_iter: 1
thrust: starting to copy
thrust: copy successful
thrust: starting loop
thrust: values_ptr: 0xb10bc0000
thrust: sequence call sucessful for iter: 0
thrust: sort_by_key successful for iter: 0
Time Cost :  0.21575212478637695
=========================
thrust: float call: float32
thrust: is_ascend: 0
thrust: n_iter: 1
thrust: starting to copy
thrust: copy successful
thrust: starting loop
thrust: values_ptr: 0xb10d00000
thrust: sequence call sucessful for iter: 0
thrust: sort_by_key successful for iter: 0
Time Cost :  0.00695347785949707
=========================
thrust: float call: float32
thrust: is_ascend: 0
thrust: n_iter: 1
thrust: starting to copy
thrust: copy successful
thrust: starting loop
thrust: values_ptr: 0xb10d00000
thrust: sequence call sucessful for iter: 0
thrust: sort_by_key successful for iter: 0
Time Cost :  0.006746768951416016

C++ output

thrust: float call: float32
thrust: is_ascend: 0
thrust: n_iter: 1
thrust: starting to copy
thrust: copy successful
thrust: starting loop
thrust: values_ptr: 0xb0e7a0000
thrust: sequence call sucessful for iter: 0
terminate called after throwing an instance of 'dmlc::Error'
  what():  [10:24:47] /home/mkrzus/github/tvm-latest/src/runtime/library_module.cc:78: Check failed: ret == 0 (-1 vs. 0) : radix_sort: failed on 2nd step: cudaErrorInvalidValue: invalid argument

Stack trace:
  [bt] (0) ./main(dmlc::LogMessageFatal::~LogMessageFatal()+0x67) [0x40d957]
  [bt] (1) /usr/local/lib/libtvm_runtime.so(+0x7a2d3) [0x7fbc6f1762d3]
  [bt] (2) /usr/local/lib/libtvm_runtime.so(+0xef097) [0x7fbc6f1eb097]
  [bt] (3) /usr/local/lib/libtvm_runtime.so(tvm::runtime::GraphRuntime::Run()+0x47) [0x7fbc6f1eb127]
  [bt] (4) ./main(MinimalYolo::forward_full(cv::Mat)+0xafb) [0x40baeb]
  [bt] (5) ./main(main+0x527) [0x408817]
  [bt] (6) /lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0xf0) [0x7fbc55812830]
  [bt] (7) ./main(_start+0x29) [0x408b69]

What I don’t quite understand is why the C++ graph runtime and the Python version behave differently. I’ve tried setting the CUDA_LAUNCH_BLOCKING=1 environment variable, but the C++ version still fails, now reporting an illegal memory access:

   thrust: sequence call sucessful for iter: 0
terminate called after throwing an instance of 'dmlc::Error'
  what():  [10:11:08] /home/mkrzus/github/tvm-latest/src/runtime/library_module.cc:78: Check failed: ret == 0 (-1 vs. 0) : radix_sort: failed on 2nd step: cudaErrorIllegalAddress: an illegal memory access was encountered
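In case it matters: CUDA_LAUNCH_BLOCKING only takes effect if it is in the environment before the CUDA runtime initializes. Besides exporting it in the shell, it can be set from inside the binary; a sketch (must run at the very top of main, before any TVM/CUDA call):

#include <cstdlib>

// force synchronous kernel launches so a failing kernel is reported at its
// call site; must execute before the CUDA runtime initializes
setenv("CUDA_LAUNCH_BLOCKING", "1", /*overwrite=*/1);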

I’ve tried playing around with various CMake flags/settings to ensure that thrust is built the same way as the tvm library, but that doesn’t seem to make a difference either.

Adding a save function for the raw pointer underneath the thrust sorting module (shown below) reveals almost no difference between the DLTensor* input in the Python graph runtime and the C++ one. Is there a difference in how the graph runtime is constructed, with respect to calling CUDA, between Python and C++? On TVM versions prior to commit 38118befc0a7e8a3db87d652b30a9369abb60363 I simply don’t have this problem (the non-thrust sort is slower, but not bad). With thrust, however, this is where I’m at. Any help would be greatly appreciated.

// minimal host-side float tensor, used below to dump the DLTensor contents to disk
class Mat{
public:
    float* m_data;
    int m_rows, m_cols, m_channels;

public:
    Mat(int cols, int rows, int channels){
        m_rows = rows;
        m_cols = cols;
        m_channels = channels;
        int size = channels * rows * cols * sizeof(float);
        m_data = (float*)malloc(size);
        memset((void *)m_data, 0, size);
    }
    ~Mat(){
        if(m_data) free(m_data);
    }

    float *at(int channel, int row, int col){
        assert(m_data != NULL);
        assert(row < m_rows);
        assert(col < m_cols);
        assert(channel < m_channels);

        return m_data + (channel * m_rows * m_cols) + row * m_cols + col;
    }

    int getRows() {return m_rows;}
    int getCols() {return m_cols;}
    int getChannels() {return m_channels;}
};

// Performs sorting along axis -1 and returns both sorted values and indices.
template<typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input,
                DLTensor* out_values,
                DLTensor* out_indices,
                bool is_ascend,
                const std::function<int(int)> &get_sort_len) {
  thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
  thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
  thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));

  int n_values = input->shape[input->ndim - 1];
  int n_iter = 1;
  for (int i = 0; i < input->ndim - 1; ++i) {
    n_iter *= input->shape[i];
  }
  std::cout << "thrust: is_ascend: " << is_ascend << std::endl;
  std::cout << "thrust: n_iter: " << n_iter << std::endl;
  std::cout << "thrust: starting to copy" << std::endl;
  thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);
  std::cout << "thrust: copy successful" << std::endl;
  std::cout << "thrust: starting loop" << std::endl;
  std::cout << "thrust: values_ptr: " << values_ptr << std::endl;

  const char* filename = "/home/mkrzus/thrust.txt";
  std::ofstream fout(filename);
  Mat vec(1, 322560, 1);
  TVMArrayCopyToBytes(input, vec.m_data, 322560 * sizeof(float));
  for (int i = 0; i < 322560; i++){
    float val = *vec.at(0,i,0);
    fout << val << "\t";
  }
  fout.close();
  for (int i = 0 ; i < n_iter; ++i) {
    n_values = get_sort_len(i);
    thrust::sequence(indices_ptr, indices_ptr + n_values);
    std::cout << "thrust: sequence call sucessful for iter: " << i << std::endl;
    if (is_ascend) {
      thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
      std::cout << "thrust: sort_by_key is_ascend successful for iter: "<< i << std::endl;
    } else {
      thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
                          thrust::greater<DataType>());
      std::cout << "thrust: sort_by_key successful for iter: "<< i << std::endl;
    }
    values_ptr += n_values;
    indices_ptr += n_values;
  }
}
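As an aside, the 322560 above is hard-coded for this model’s score tensor; a size-agnostic version of the dump (a sketch using only TVMArrayCopyToBytes plus a std::vector, assuming <vector> and <fstream> are included) would look like:

// derive the element count from the DLTensor shape instead of hard-coding it
int64_t num_elems = 1;
for (int d = 0; d < input->ndim; ++d) num_elems *= input->shape[d];

std::vector<float> host_buf(num_elems);
TVMArrayCopyToBytes(input, host_buf.data(), num_elems * sizeof(float));

std::ofstream fout("/home/mkrzus/thrust.txt");
for (int64_t i = 0; i < num_elems; ++i) fout << host_buf[i] << "\t";
fout.close();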

I was able to run your code but couldn’t reproduce your error. Everything seems fine with C++. I’m using TVM (commit id = 7bee0eaeae0c38b7a6d3eb187c5de68880daebf6) compiled with CUDA 11, thrust and cublas (no cudnn).


I’m running CUDA 10 and 10.2 with cudnn 7.5 on TX2s. I’ll try updating CUDA and report back. @ymwangg are you running on x86?

Yes, I’m using an AWS p3.2xlarge instance with a Tesla V100 GPU.

Hello, were you ever able to find a solution to this issue? I may be experiencing a similar problem.
