[CODEGEN] How do I disable the generated assertion code for shape/type checks?

Hi All,

I have a question. Why does TVM generate this code to check the shapes and types of the arguments? Can I disable the checking code emitted into the generated .ll file to avoid the performance overhead?

Example (excerpt of the generated LLVM IR):

define dllexport i32 @fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_33(i8* noalias nocapture readonly %0, i8* noalias nocapture readonly %1, i32 %2, i8* noalias nocapture readnone %3, i8* noalias nocapture readnone %4, i8* noalias nocapture readnone %5) {
entry:
  %6 = icmp eq i32 %2, 4
  br i1 %6, label %assert_end, label %assert_fail, !prof !5

assert_fail:                                      ; preds = %entry
  tail call void @TVMAPISetLastError(i8* getelementptr inbounds ([116 x i8], [116 x i8]* @.str, i64 0, i64 0))
  ret i32 -1

assert_end:                                       ; preds = %entry
  %7 = bitcast i8* %0 to %1**
  %8 = load %1*, %1** %7, align 8
  %9 = bitcast i8* %1 to i32*
  %10 = load i32, i32* %9, align 4, !tbaa !6
  %11 = getelementptr inbounds i8, i8* %0, i64 8
  %12 = bitcast i8* %11 to %1**
  %13 = load %1*, %1** %12, align 8
  %14 = getelementptr inbounds i8, i8* %1, i64 4
  %15 = bitcast i8* %14 to i32*
  %16 = load i32, i32* %15, align 4, !tbaa !21
  %17 = getelementptr inbounds i8, i8* %0, i64 16
  %18 = bitcast i8* %17 to %1**
  %19 = load %1*, %1** %18, align 8
  %20 = getelementptr inbounds i8, i8* %1, i64 8
  %21 = bitcast i8* %20 to i32*
  %22 = load i32, i32* %21, align 4, !tbaa !23
  %23 = getelementptr inbounds i8, i8* %0, i64 24
  %24 = bitcast i8* %23 to %1**
  %25 = load %1*, %1** %24, align 8
  %26 = getelementptr inbounds i8, i8* %1, i64 12
  %27 = bitcast i8* %26 to i32*
  %28 = load i32, i32* %27, align 4, !tbaa !26
  %29 = getelementptr inbounds %1, %1* %8, i64 0, i32 0
  %30 = load i8*, i8** %29, align 8
  call void @llvm.assume(i1 true) [ "align"(i8* %30, i64 128) ]
  %31 = getelementptr inbounds %1, %1* %8, i64 0, i32 4
  %32 = load i64*, i64** %31, align 8
  %33 = getelementptr inbounds %1, %1* %8, i64 0, i32 5
  %34 = load i64*, i64** %33, align 8
  %35 = getelementptr inbounds %1, %1* %8, i64 0, i32 1, i32 1
  %36 = load i32, i32* %35, align 4
  %37 = getelementptr inbounds %1, %1* %13, i64 0, i32 0
  %38 = load i8*, i8** %37, align 8
  call void @llvm.assume(i1 true) [ "align"(i8* %38, i64 128) ]
  %39 = getelementptr inbounds %1, %1* %13, i64 0, i32 4
  %40 = load i64*, i64** %39, align 8
  %41 = getelementptr inbounds %1, %1* %13, i64 0, i32 5
  %42 = load i64*, i64** %41, align 8
  %43 = getelementptr inbounds %1, %1* %19, i64 0, i32 0
  %44 = load i8*, i8** %43, align 8
  call void @llvm.assume(i1 true) [ "align"(i8* %44, i64 128) ]
  %45 = getelementptr inbounds %1, %1* %19, i64 0, i32 4
  %46 = load i64*, i64** %45, align 8
  %47 = getelementptr inbounds %1, %1* %19, i64 0, i32 5
  %48 = load i64*, i64** %47, align 8
  %49 = getelementptr inbounds %1, %1* %25, i64 0, i32 0
  %50 = load i8*, i8** %49, align 8
  call void @llvm.assume(i1 true) [ "align"(i8* %50, i64 128) ]
  %51 = getelementptr inbounds %1, %1* %25, i64 0, i32 4
  %52 = load i64*, i64** %51, align 8
  %53 = getelementptr inbounds %1, %1* %25, i64 0, i32 5
  %54 = load i64*, i64** %53, align 8
  switch i32 %10, label %assert_fail1 [
    i32 13, label %assert_end2
    i32 7, label %assert_end2
    i32 4, label %assert_end2
    i32 3, label %assert_end2
  ]

assert_fail1:                                     ; preds = %assert_end
  tail call void @TVMAPISetLastError(i8* getelementptr inbounds ([191 x i8], [191 x i8]* @.str.1, i64 0, i64 0))
  ret i32 -1

assert_end2:                                      ; preds = %assert_end, %assert_end, %assert_end, %assert_end
  switch i32 %16, label %assert_fail3 [
    i32 13, label %assert_end4
    i32 7, label %assert_end4
    i32 4, label %assert_end4
    i32 3, label %assert_end4
  ]

assert_fail3:                                     ; preds = %assert_end2
  tail call void @TVMAPISetLastError(i8* getelementptr inbounds ([191 x i8], [191 x i8]* @.str.2, i64 0, i64 0))
  ret i32 -1

assert_end4:                                      ; preds = %assert_end2, %assert_end2, %assert_end2, %assert_end2
  switch i32 %22, label %assert_fail5 [
    i32 13, label %assert_end6
    i32 7, label %assert_end6
    i32 4, label %assert_end6
    i32 3, label %assert_end6
  ]

assert_fail5:                                     ; preds = %assert_end4
  tail call void @TVMAPISetLastError(i8* getelementptr inbounds ([191 x i8], [191 x i8]* @.str.3, i64 0, i64 0))
  ret i32 -1

assert_end6:                                      ; preds = %assert_end4, %assert_end4, %assert_end4, %assert_end4
  switch i32 %28, label %assert_fail7 [
    i32 13, label %assert_end8
    i32 7, label %assert_end8
    i32 4, label %assert_end8
    i32 3, label %assert_end8
  ]

assert_fail7:                                     ; preds = %assert_end6
  tail call void @TVMAPISetLastError(i8* getelementptr inbounds ([191 x i8], [191 x i8]* @.str.4, i64 0, i64 0))
  ret i32 -1

assert_end8:                                      ; preds = %assert_end6, %assert_end6, %assert_end6, %assert_end6
  %55 = getelementptr inbounds %1, %1* %8, i64 0, i32 2
  %56 = load i32, i32* %55, align 4
  %57 = icmp eq i32 %56, 4
  br i1 %57, label %assert_end12, label %assert_fail9, !prof !5

  ...

  %262 = tail call fastcc i32 @fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_33_compute_(i8* %30, i8* %38, i8* %50, i8* %44, i32 %36)  

Can I disable these assertions so that the generated function simply calls xxx_compute_ directly?

Thank you so much.

with tvm.transform.PassContext(config={"tir.disable_assert": True}):
        m = tvm.build(mod, [x, y, z], target="llvm")

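Setting `"tir.disable_assert": True` in the `PassContext` config turns off these assertions :grinning:. Note that this also removes the runtime argument checks, such as the argument-count and type-code checks shown above. As a result, a call with mismatched arguments will no longer report an error through `TVMAPISetLastError`.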

1 Like