This RFC outlines the steps we would like to take to introduce an AOT compiler for deploying models. Some of the proposed approaches depend on the Unified IR enhancements, but we feel it is good to first discuss the technical choices with the community so we can be prepared and drive the related designs in the right direction.
Motivation
For domains where functional safety requirements are important, some users would like an AOT compilation mode of TVM that does not depend on the graph runtime, and that potentially allows a user to ship smaller binaries with consistent and minimal dependencies, e.g. for embedded devices.
Here, AOT means that we compile not only the operators but also the graph interpretation part of the execution.
This is a draft RFC to outline the key design decisions that are relevant to an AOT compiler.
Runtime API
We will build a consistent runtime.Module-based API as in [DISCUSS] Module based Model Runtime Interface.
The same API can then be used for both the AOT and the graph runtime.
The only technical challenge for the runtime is that we will need an alternative minimal version that only wraps the C API (without C++11, due to the lack of support typically found on microcontrollers). Such a runtime can be implemented in languages like C or Rust, or generated through codegen.
Example Raw C API usage:
// handles to the loaded module and its packed functions
void *lib, *fset_input, *frun, *fget_output;
TVMModLoadFromFile("resnet.so", "so", &lib);
TVMModGetFunction(lib, "set_input", 0, &fset_input);
TVMModGetFunction(lib, "get_output", 0, &fget_output);
TVMModGetFunction(lib, "run", 0, &frun);
// call into the PackedFuncs
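For completeness, here is a minimal sketch of how the retrieved packed functions could then be invoked, assuming the TVMFuncCall/TVMValue calling convention of the TVM C runtime API; the "data" input name and the input_tensor variable are illustrative only, not part of the proposal.

// Sketch only: invoke the packed functions fetched above via TVMFuncCall.
// The input name "data" and input_tensor are illustrative placeholders.
DLTensor* input_tensor = /* prepared by the application */ NULL;
TVMValue args[2];
int type_codes[2];
TVMValue ret_val;
int ret_type_code;

// set_input("data", input_tensor)
args[0].v_str = "data";
type_codes[0] = kTVMStr;
args[1].v_handle = input_tensor;
type_codes[1] = kTVMDLTensorHandle;
TVMFuncCall(fset_input, args, type_codes, 2, &ret_val, &ret_type_code);

// run() takes no arguments
TVMFuncCall(frun, args, type_codes, 0, &ret_val, &ret_type_code);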
Graph AOT vs Fully Featured Relay AOT
As a starting point, we could begin with a Graph AOT that supports a limited subset of Relay programs. One goal would be to eventually support a fully featured Relay AOT, which would bring a dependency on a dynamic memory allocator but also support advanced features like control flow.
Runtime State Data Structure
The runtime should still depend on a minimal set of basic primitives, in particular ways to allocate an array of DLTensor and to set up the memory space so that it can be accessed from generated code. This means we need a data structure (let us name it GraphRuntimeState) that holds the array of DLTensor and the PackedFuncs. This data structure needs to be accessible from the generated code, which means it is best implemented in a C ABI compatible way.
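As an illustration of what such a C ABI compatible structure could look like, here is a hypothetical sketch; the field names and layout are examples, not part of the proposal.

#include <stddef.h>
#include <tvm/runtime/c_runtime_api.h>  // DLTensor, TVMFunctionHandle

// Hypothetical sketch of a C ABI compatible runtime state; fields are
// illustrative, not a proposed layout.
typedef struct {
  DLTensor** data;             // input, output and intermediate tensors
  size_t num_data;
  TVMFunctionHandle* funcs;    // packed functions of the fused operators
  size_t num_funcs;
} GraphRuntimeState;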
One way to unify this data structure with the runtime system is to make use of the Object protocol, so that the GraphRuntimeState can be accessed from any of the frontend languages in a compatible way.
Possible Technical Paths
P0: Relay -> C/C++
The simplest approach, based on the Relay AOT POC, is to directly transpile a Relay program into C API calls into the TVM runtime. The drawback of this approach is that it brings a dependency on the C API. We could also create an LLVM backend; however, see P1.
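To make this concrete, here is a rough sketch of what P0-style transpiled C could look like for a small graph, reusing the hypothetical GraphRuntimeState sketch from above and the TVMFuncCall C API; all names are illustrative.

// Hypothetical output of a Relay -> C transpilation (P0): each fused
// operator in the graph becomes one call into the TVM C runtime API.
void graph_run(GraphRuntimeState* state) {
  TVMValue args[2];
  int codes[2] = {kTVMDLTensorHandle, kTVMDLTensorHandle};
  TVMValue ret;
  int ret_code;

  // layer0: data[0] -> data[1]
  args[0].v_handle = state->data[0];
  args[1].v_handle = state->data[1];
  TVMFuncCall(state->funcs[0], args, codes, 2, &ret, &ret_code);

  // layer1: data[1] -> data[2]
  args[0].v_handle = state->data[1];
  args[1].v_handle = state->data[2];
  TVMFuncCall(state->funcs[1], args, codes, 2, &ret, &ret_code);
}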
P1: Relay -> TIR::Function -> runtime.Module
As an alternative approach, we can first lower the Relay function into a TIR::Function that corresponds to the low-level actions taken by the runtime. Then we can call the existing code generators to lower the TIR::Function into the final runtime.Module.
This is a more desirable approach in the world of the unified IR, because we don't have to build a specific code generator backend for Relay and can directly reuse TIR's code generators.
Most of the key technical challenges in this path come down to making TIR::Function expressive enough to represent the low-level operations of a graph executor. The code below shows a mocked-up text representation of what the low-level IR could look like. In order to lower this IR, we will need to be able to handle objects (GraphRuntimeState and Array) in the TIR. Once we are able to do that, we can bring in flexible implementations, including support for additional data structures (via Object).
# mocked up syntax to show the corresponding low-level IR
def @graph_init():
  # allocate the tensors used by the graph
  %arr = @Array.Create()
  @Array.push_back(%arr, @NDArray.empty([%const_shape0]))
  @Array.push_back(%arr, @NDArray.empty([%const_shape1]))

def @graph_run():
  # fetch the runtime state and invoke the fused operators in order
  %ctx = @context.GetGraphRuntimeState()
  @call_packed("layer0", %ctx.data[0], %ctx.data[1])
  @call_packed("layer1", %ctx.data[1], %ctx.data[2])
  @call_packed("layer2", %ctx.data[2], %ctx.data[3])