Logic Errors in PrimFunc without strict Check

Many times, when I use the relax.build() API to deploy an AI model, errors occur in tvm::codegen::CodeGenLLVM::Verify() after tvm::codegen::CodeGenLLVM::Finish(). There are generally two types of errors on the LLVM side: a) operands are not of the same type, e.g.:

Both operands to ICmp instruction are not of the same type!
%109 = icmp eq i32 %1, i64 1
Both operands to ICmp instruction are not of the same type!
 %119 = icmp eq i32 %1, i64 1

b) Found return instr that returns non-void in Function of void return type, e.g.:

Found return instr that returns non-void in Function of void return type!
ret i32 %12
voidFound return instr that returns non-void in Function of void return type!
 ret i32 %74
 voidFound return instr that returns non-void in Function of void return type!
 ret i32 %9
void

I have already reported this kind of issue; see "Int32/Int64 issue when codegen into llvm::Function?".

To my limited knowledge, we could transform the PrimFunc into a legal form before lowering it to an llvm::Function. I did find some passes in tir/transform, such as tir.transform.NarrowDataType(), tir.transform.ForceNarrowIndexToInt32(), etc., but they have not fully solved my problem.

Are there any ideas on how to address the above issues?

It would be great if you could provide a reproducible script :slight_smile:

Okay. As this is a home-made GPU with an AMD-style architecture, the target implementation code has not been open-sourced, so I will try to provide a unit test (UT) on one of the targets that TVM already supports.

I am trying to reproduce the issue with a UT on an NVIDIA A10 GPU; however, in the latest tvm/unity branch, the Relax frontend does not support the NonMaxSuppression op. I had forgotten that I added NonMaxSuppression to the Relax frontend in our own code repo. The following script tries to load the NMS op by reusing the Relay frontend and translating with relay_translator(); a similar error still occurs, related to the BufferVar dtype.

        '''
    pip3 install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple/
    pip3 install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple/
    '''
    import tvm
    from tvm import relax
    from tvm.script import ir as I
    from tvm.script import relax as R
    from tvm.script import tir as T 
    # from tvm.relax.frontend.onnx import from_onnx
    from tvm.relay.frontend.onnx import from_onnx
    from tvm.relax.testing import relay_translator
    from tvm.relax.transform import LegalizeOps
    from tvm.tir.transform import DefaultGPUSchedule
    import numpy as np 
    import onnx 

    import onnx
    from onnx import helper
    from onnx import TensorProto

    def gen_nms_model():
        """Build a single-node ONNX model wrapping ``NonMaxSuppression``.

        The graph takes five inputs (``boxes``, ``scores``,
        ``max_output_boxes_per_class``, ``iou_threshold``, ``score_threshold``)
        and produces ``selected_indices``. The model is validated with
        ``onnx.checker.check_model`` before it is returned.
        """
        # (name, element type, shape) of every graph input, in operator order.
        input_specs = [
            ("boxes", TensorProto.FLOAT, [8, 25200, 4]),
            ("scores", TensorProto.FLOAT, [8, 1, 25200]),
            ("max_output_boxes_per_class", TensorProto.INT64, [1]),
            ("iou_threshold", TensorProto.FLOAT, [1]),
            ("score_threshold", TensorProto.FLOAT, [1]),
        ]
        graph_inputs = [
            helper.make_tensor_value_info(name, elem_type, shape)
            for name, elem_type, shape in input_specs
        ]
        input_names = [name for name, _, _ in input_specs]

        # The leading dimension of the output is left unspecified (None).
        graph_output = helper.make_tensor_value_info(
            "selected_indices", TensorProto.INT64, [None, 3]
        )

        node = helper.make_node(
            "NonMaxSuppression",
            inputs=input_names,
            outputs=["selected_indices"],
        )
        graph = helper.make_graph([node], "nms_graph", graph_inputs, [graph_output])

        model = helper.make_model(graph, producer_name="onnx-nms")
        onnx.checker.check_model(model)
        return model


    def foo():
        """Import the saved NMS ONNX model, translate it to Relax and build it.

        Loads ``nms_model.onnx``, converts it with the Relay ONNX frontend,
        writes the printed import result to ``nms.py``, translates the
        ``main`` function to Relax, applies ``LegalizeOps`` and
        ``DefaultGPUSchedule`` under the target, then builds the module and
        creates a Relax virtual machine on the device.
        """
        # To try the CPU path instead, use Target("llvm") with tvm.cpu(0).
        gpu_target = tvm.target.Target("nvidia/nvidia-a10")
        dev = tvm.cuda(0)

        model_proto = onnx.load("nms_model.onnx")
        imported = from_onnx(model_proto)

        # Keep a textual dump of the import result for inspection.
        with open("nms.py", "w") as dump_file:
            print(imported, file=dump_file)

        # Only the "main" function of the first element of the import result
        # is translated.
        relax_mod = relay_translator.from_relay(imported[0]["main"], gpu_target)

        pipeline = tvm.transform.Sequential([LegalizeOps(), DefaultGPUSchedule()])
        with gpu_target:
            relax_mod = pipeline(relax_mod)

        executable = relax.build(relax_mod, target=gpu_target)

        # The VM is constructed but not run; building it is the point here.
        vm = relax.VirtualMachine(executable, dev)

        print("done.")
        


    # Build the NMS model, save it to the path that foo() loads, then run foo().
    nms_model = gen_nms_model()
    onnx.save(nms_model, "nms_model.onnx")
    foo()

The NonMaxSuppression op is quite important in CV models. Due to my limited ability, there are some bugs in the NMS OpConverter that I added to the Relax frontend. The following is my implementation:

    class NonMaxSuppression(OnnxOpConverter):
        """Operator converter for NonMaxSuppression."""

        @classmethod
        def _impl_v10(cls, bb, inputs, attr, params):
            """Convert an ONNX ``NonMaxSuppression`` node into Relax expressions.

            Parameters
            ----------
            bb
                Block builder used to normalize expressions and to emit the
                TOPI computation via ``emit_te``.
            inputs
                Operator inputs in the order ``boxes``, ``scores``,
                ``max_output_boxes_per_class``, ``iou_threshold``,
                ``score_threshold``. The last three may be ``None``.
            attr
                Node attributes; only ``center_point_box`` is read.
            params
                Not used by this converter.

            Returns
            -------
            The normalized ``dynamic_strided_slice`` of the first NMS output,
            sliced from ``[0, 0]`` up to ``concat(nms_out[1], [3])``.
            """
            # Get parameter values
            boxes = inputs[0]
            scores = inputs[1]
            max_output_boxes_per_class = inputs[2]
            iou_threshold = inputs[3]
            score_threshold = inputs[4]
            boxes_dtype = boxes.checked_type.dtype

            if attr.get("center_point_box", 0) != 0:
                # Boxes are given as (x_center, y_center, width, height);
                # rewrite them into corner form ordered as [y1, x1, y2, x2].
                xc, yc, w, h = bb.normalize(relax.op.split(boxes, 4, axis=2))
                half_w = w / relax.expr.const(2.0, boxes_dtype)
                half_h = h / relax.expr.const(2.0, boxes_dtype)
                x1 = xc - half_w
                x2 = xc + half_w
                y1 = yc - half_h
                y2 = yc + half_h
                boxes = bb.normalize(relax.op.concat([y1, x1, y2, x2], axis=2))

            # Optional inputs that were omitted fall back to defaults. Without
            # a default, a missing max_output_boxes_per_class would crash below
            # when its struct_info is read. ONNX specifies 0 as its default.
            if max_output_boxes_per_class is None:
                max_output_boxes_per_class = relax.expr.const(0, dtype="int64")
            if iou_threshold is None:
                iou_threshold = relax.expr.const(0.0, dtype="float32")
            if score_threshold is None:
                score_threshold = relax.expr.const(0.0, dtype="float32")

            def conditionally_squeeze_scalar(x):
                """Return ``x`` as a rank-0 expression, squeezing a rank-1 input."""
                rank = len(x.struct_info.shape)
                assert rank <= 1, "nms thresholds must be scalars"
                if rank == 1:
                    return relax.op.squeeze(x, [0])
                return x

            max_output_boxes_per_class = conditionally_squeeze_scalar(max_output_boxes_per_class)
            iou_threshold = conditionally_squeeze_scalar(iou_threshold)
            score_threshold = conditionally_squeeze_scalar(score_threshold)

            # The squeezed expressions must be normalized before emit_te,
            # otherwise TVM fails with: ICHECK(ptr) << "The struct_info is not
            # populated, check if you have normalized the expr".
            max_output_boxes_per_class = bb.normalize(max_output_boxes_per_class)
            iou_threshold = bb.normalize(iou_threshold)
            score_threshold = bb.normalize(score_threshold)

            # NOTE(review): nms_out is presumably (selected_indices,
            # num_detections) - confirm against the TOPI implementation.
            nms_out = bb.emit_te(
                topi.vision.all_class_non_max_suppression,
                boxes,
                scores,
                max_output_boxes_per_class,
                iou_threshold,
                score_threshold,
            )

            # Slice the first output from [0, 0] up to [nms_out[1], 3].
            begin = relax.const([0, 0], dtype="int64")
            end = bb.normalize(relax.op.concat([nms_out[1], relax.const([3], dtype="int64")]))
            strides = relax.const([1, 1], dtype="int64")

            return bb.normalize(
                relax.op.dynamic_strided_slice(nms_out[0], begin=begin, end=end, strides=strides)
            )