Implementing AOT in TVM

To further illustrate what I meant by the impact of compiler optimizations, I ran the following quick experiment:

// test.cc
#include <tvm/runtime/c_runtime_api.h>                                                                                                                                                        
                                                 
// Callee written against the PackedCFunc calling convention.
// Reads an integer from args[0] and a pointer from args[1], then stores
// args[0] + *(int*)args[1] into the return slot, tagged as kTVMArgInt.
// type_codes, num_args and resource_handle are not consulted here.
// Always returns 0.
inline int PackedCFunc(TVMValue* args, int* type_codes, int num_args,
                       TVMValue* out_ret_value, int* out_ret_tcode,
                       void* resource_handle) {
  // Narrow the 64-bit argument slot back to int.
  const int lhs = static_cast<int>(args[0].v_int64);
  int* const rhs = static_cast<int*>(args[1].v_handle);

  // Tag the result before loading through rhs, same order as the stores
  // and load in the straightforward formulation.
  *out_ret_tcode = kTVMArgInt;
  out_ret_value->v_int64 = lhs + *rhs;
  return 0;
}
                                                                                                                                                                                              
// return x + ptr[0];                                                                                                                                                                         
extern "C" int AddViaPackedCFunc(int x, int* ptr) {                                                                                                                                           
  TVMValue args[2];                                                                                                                                                                           
  int type_codes[2];                                                                                                                                                                          
  TVMValue out_ret_value;                                                                                                                                                                     
  int out_ret_tcode;                                                                                                                                                                          
                                                                                                                                                                                              
  args[0].v_int64 = x;                                                                                                                                                                        
  args[1].v_handle = ptr;                                                                                                                                                                     
  type_codes[0] = kTVMArgInt;                                                                                                                                                                 
  type_codes[1] = kTVMOpaqueHandle;                                                                                                                                                           
  PackedCFunc(args, type_codes, 2, &out_ret_value, &out_ret_tcode, nullptr);                                                                                                                  
  return out_ret_value.v_int64;                                                                                                                                                               
}                                                                                                                                                                                  
                                                                                                                                                                                                                                                                                       

Result of Clang

Run command

clang-10 -O2 -S -emit-llvm -I /path/to/tvm/3rdparty/dlpack/include -I /path/to/tvm/include -o test.ll test.cc   
cat test.ll

This gives the following code (metadata removed):

; Function Attrs: nounwind readonly uwtable
define dso_local i32 @AddViaPackedCFunc(i32 %0, i32* %1) local_unnamed_addr #0 {
  %3 = load i32, i32* %1, align 4, !tbaa !2
  %4 = add nsw i32 %3, %0
  ret i32 %4
}

Result of GCC

gcc -O2 -S  -I /path/to/tvm/3rdparty/dlpack/include -I /path/to/tvm/include -o test.s test.cc
cat test.s
	.file	"test.cc"
	.text
	.p2align 4,,15
	.globl	AddViaPackedCFunc
	.type	AddViaPackedCFunc, @function
AddViaPackedCFunc:
.LFB1:
	.cfi_startproc
	movl	(%rsi), %eax
	addl	%edi, %eax
	ret
	.cfi_endproc
.LFE1:
	.size	AddViaPackedCFunc, .-AddViaPackedCFunc
	.ident	"GCC: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0"
	.section	.note.GNU-stack,"",@progbits

Discussions

As we can see, this is essentially equivalent to the direct C call:

/* Direct C equivalent of the packed call: returns x plus the int that
 * ptr points to. ptr must point to a readable int. */
int Add(int x, int *ptr) {
  return x + ptr[0];
}

To understand what is happening under the hood, the following optimizations lead to this result:

  • Inlining, which inlines the call to PackedCFunc
  • Mem2reg, which promotes the stack stores/loads to register operations
  • Dead-code elimination, which eliminates the unused type id
  • Reasoning about int32 passed via int64: cast<int32>(cast<int64>(x)) = x when x is i32

Compiling Code with TypeId Checking

The compiler can do even smarter things when the code already includes the type code check. If we run the same experiment on the following code, we find that the result is the same as the direct C call without any type id checking. This is because the compiler can inline, constant-fold, and then dead-code eliminate the type id checking part.

#include <cstdio>                                                                                                                                                                             
#include <tvm/runtime/c_runtime_api.h>                                                                                                                                                        
                                                                                                                                                                                              
// PackedCFunc-convention callee with argument type validation.
// Expects type_codes[0] == kTVMArgInt and type_codes[1] == kTVMOpaqueHandle;
// returns -1 without touching the output slots if either differs.
// On success stores args[0] + *(int*)args[1] into the return slot, tags it
// as kTVMArgInt and returns 0. num_args and resource_handle are unused.
inline int PackedCFunc(TVMValue* args, int* type_codes, int num_args,
                       TVMValue* out_ret_value, int* out_ret_tcode,
                       void* resource_handle) {
  // Narrow the 64-bit argument slot back to int.
  const int lhs = static_cast<int>(args[0].v_int64);
  int* const rhs = static_cast<int*>(args[1].v_handle);

  // Type check that the optimizer can fold away when the caller's type
  // codes are compile-time constants.
  if (type_codes[0] != kTVMArgInt || type_codes[1] != kTVMOpaqueHandle) {
    return -1;
  }

  *out_ret_tcode = kTVMArgInt;
  out_ret_value->v_int64 = lhs + *rhs;
  return 0;
}
                                                                                                                                                                                              
// return x + ptr[0];                                                                                                                                                                         
extern "C" int AddViaPackedCFunc(int x, int* ptr) {                                                                                                                                           
  TVMValue args[2];                                                                                                                                                                           
  int type_codes[2];                                                                                                                                                                          
  TVMValue out_ret_value;                                                                                                                                                                     
  int out_ret_tcode;                                                                                                                                                                          
                                                                                                                                                                                              
  args[0].v_int64 = x;                                                                                                                                                                        
  args[1].v_handle = ptr;                                                                                                                                                                     
  type_codes[0] = kTVMArgInt;                                                                                                                                                                 
  type_codes[1] = kTVMOpaqueHandle;                                                                                                                                                           
                                                                                                                                                                                              
  // note: check can be dead-code eliminated                                                                                                                                                  
  if (PackedCFunc(args, type_codes, 2, &out_ret_value, &out_ret_tcode, nullptr) != 0) {                                                                                                       
    printf("error\n");                                                                                                                                                                        
  }                                                                                                                                                                                           
  if (out_ret_tcode != kTVMArgInt) {                                                                                                                                                          
    printf("error\n");                                                                                                                                                                        
  }                                                                                                                                                                                           
  return out_ret_value.v_int64;                                                                                                                                                               
}                                                                                                                                                                                             
                            
1 Like