The module interface works great for deploying the generated operator libraries in TVM. However, we still face the challenge of providing a common deployment interface for machine learning models.
Our most commonly used interface is the graph runtime, whose mechanism is a bit different from low-level code loading (it is not as simple as module.load). In past RFCs we have argued that we should use the same module mechanism to deploy ML models. However, there are still quite a few challenges that need to be addressed:
- R1: The creation of an ML model runtime can be context-dependent. For example, the user needs to be able to specify which GPU to run the graph runtime on.
- R2: We are starting to have multiple variations of model runtime, such as the RelayVM. While it does not make sense to force all model runtimes to have the same set of APIs, it would be helpful to have the same mechanism for packaging and loading.
- R3: In advanced use cases, we want to be able to bundle multiple models into a single shared library.
It is not hard to propose a few interfaces that “solve” the above challenges. The hard part is agreeing on “the interface convention” for TVM’s ML model packaging.
As a basic principle, we build the convention directly on top of the current Module system. An important thing to keep in mind is that we want users to learn as few new concepts as possible.
The Raw Interface Strawman
Given that all the additional wrapping boils down to the raw module interface, we start the discussion with a strawman proposal that uses only the raw module interface.
```python
# lib is a GraphRuntimeFactoryModule
# that contains the graph json and the parameters
lib = tvm.module.load("resnet18.so")

# Call into the factory module to create a graph runtime.
# Having this additional factory create step solves R1.
# Note that the parameters are already set.
#
# The first argument is a key that helps to solve R3.
#
# Alternative API 0: take the model name as a key.
gmod = lib["runtime_create"]("resnet18", tvm.cpu(0))
# Alternative API 1: select the model, then construct.
gmod = lib["runtime_select"]("resnet18")(tvm.cpu(0))
# Alternative API 2: directly key the constructor by model name.
gmod = lib["resnet18"](tvm.cpu(0))

set_input = gmod["set_input"]
run = gmod["run"]
get_output = gmod["get_output"]

# We do not need to set the parameters here,
# as they are already bundled in the factory module.
set_input(data=my_data)
run()
out = get_output(0)

# Alternative: sklearn-style predict function
# that might work better for the VM; might help solve R2.
predict = gmod["predict"]
output = predict(data=my_data, out=out_data)
```
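Because the factory is itself just a Module, the convention should compose with the existing RPC infrastructure without introducing new concepts. A minimal sketch, assuming the strawman convention above (the host address and file name are placeholders):

```python
from tvm import rpc

# ship the library to a remote device and load it there
remote = rpc.connect("192.168.0.2", 9090)
remote.upload("resnet18.so")
lib = remote.load_module("resnet18.so")

# the same factory call, now keyed to a remote context
gmod = lib["resnet18"](remote.cpu(0))
gmod["set_input"](data=my_data)
gmod["run"]()
out = gmod["get_output"](0)
```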
A few highlights:
- H1: Instead of directly returning a GraphRuntime module when we load, we only load a factory module that contains the necessary metadata. A subsequent call into the factory's create function then creates the actual graph runtime module.
- H2: The create function takes a model name as a key, which potentially allows bundling multiple models/modules into the same shared library.
- H3: The graph json and the parameters can be bundled into the factory module, so the create function only has to take a context parameter. This interface also brings future benefits: for example, we can use AOT compilation to generate logic that runs the graph runtime while keeping the same interface (see the export sketch after this list).
- H4: Depending on which interface we encourage (set/run/get vs. predict), we can have different levels of interface sharing between the VM and the graph runtime.
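To make H3 concrete, here is a hypothetical export flow under this proposal. It assumes relay.build is changed to return a single factory module that bundles the graph json and parameters (today it returns the graph, the library, and the parameters separately); the mod_name argument is likewise hypothetical:

```python
import tvm
from tvm import relay

# mod and params come from a frontend importer, e.g. relay.frontend.from_onnx
# hypothetical: build returns one factory module keyed by "resnet18"
factory = relay.build(mod, target="llvm", params=params, mod_name="resnet18")
factory.export_library("resnet18.so")

# deployment side: load and create, as in the strawman above
lib = tvm.module.load("resnet18.so")
gmod = lib["resnet18"](tvm.cpu(0))
```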
Discussions
Here are some points for discussion:
- D1: Do you like the factory pattern? Shall we always require a model name field (and allow “default”), or shall we take the alternative API specialization approach?
- D2: The set/run/get interface vs. predict:
  - The set interface is useful for allowing users to set parameters at runtime.
  - run is useful for doing fine-grained benchmarking (see the sketch after this list).
  - predict is a higher-level, user-friendly API; note that we still want to allow the destination-passing style (passing out) for more flexibility.
  - predict forces us to support runtime tuples in the case of multiple outputs, while get_output keeps things simple and minimal.
- D3: A runtime argument specification convention for multiple contexts in the heterogeneous case, under the restriction that PackedFunc only takes positional arguments. For example, one possible convention could be lib["resnet18"](tvm.cpu(0), tvm.gpu(0)).
- D4: Do you like the new way of packaging, or is it fine to continue using the old graph runtime API? The old flow is sketched below for comparison.
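For D2, the split set/run/get interface is what makes fine-grained benchmarking possible. Below is a minimal sketch using the existing time_evaluator mechanism on the created runtime module; the library and model key follow the strawman above.

```python
# benchmark only the run step, excluding input setup and output fetch
gmod = lib["resnet18"](tvm.cpu(0))
gmod["set_input"](data=my_data)
ftimer = gmod.time_evaluator("run", tvm.cpu(0), number=10, repeat=3)
print("mean run time: %.2f ms" % (ftimer().mean * 1000))
```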
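For D4, the flow that the proposal would replace looks roughly like the following; the file names are placeholders, and the user has to carry the graph json, the library, and the parameters as three separate artifacts.

```python
from tvm.contrib import graph_runtime

# today: three artifacts travel together
loaded_lib = tvm.module.load("deploy_lib.so")
loaded_json = open("deploy_graph.json").read()
loaded_params = bytearray(open("deploy_param.params", "rb").read())

# create the runtime, then load the parameters explicitly
module = graph_runtime.create(loaded_json, loaded_lib, tvm.cpu(0))
module.load_params(loaded_params)
module.set_input(data=my_data)
module.run()
out = module.get_output(0)
```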
API Wrapping
Most of the above discussion concerns the raw API. To make life easier for our users, we can still provide minimal wrappers around the raw API.
Ask the user to specify the wrapper type
The first option is to ask the user to directly construct the wrapper API using a constructor/create function.
```python
# lib is a GraphRuntimeFactoryModule
# that contains the graph json and the parameters
lib = tvm.module.load("resnet18.so")
gmod = graph_runtime.create(lib["resnet18"], ctx=tvm.cpu(0))
gmod.set_input(data=my_data)
# sklearn-style predict API
out = gmod.predict(data=my_data)
```
Note that the main purpose of the wrapper API is to provide clear documentation for the most common use cases. The full power of the module is always available through the raw API.
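To illustrate how thin such a wrapper can be, here is a minimal sketch of a hypothetical GraphModule-style class. It only caches the packed functions and documents them; every call still goes through the raw module interface:

```python
class GraphModule(object):
    """Thin wrapper documenting the common graph runtime calls."""

    def __init__(self, module):
        self.module = module
        self._set_input = module["set_input"]
        self._run = module["run"]
        self._get_output = module["get_output"]

    def set_input(self, **kwargs):
        """Set input tensors by name."""
        self._set_input(**kwargs)

    def run(self):
        """Execute one forward pass."""
        self._run()

    def get_output(self, index):
        """Fetch the index-th output tensor."""
        return self._get_output(index)
```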
Automatically create wrapper type via type_key
Suggested by @FrozenGene. Alternatively, we can automatically create a wrapped module class using the type key. This requires us to handle RPC in a clean way (instead of using RPCModule as the key, we need to get the type key of the module on the remote).
Compared with the approach above, the user does not need to specify the module wrapper; the wrapper class is created directly during load.
It does complicate the module loading and return logic a bit (e.g. do we also need to do this for all of our modules, just like the different node variations?). The possible inconsistency between the wrapper class and the underlying module type (e.g. an RPCModule can be wrapped as a GraphRuntime) also makes this API a bit confusing, whereas in the node and object system the type of the subclass directly corresponds to the type of the object.
Would love to hear everyone’s thoughts about these two kinds of APIs.
tvm.register_module_wrapper("GraphRuntimeFactory", GraphRuntimeFactory)
tvm.register_module_wrapper("GraphRuntime", GraphRuntime)
# gmod factory is an wrapper automatically created with type key
gmodfactory = tvm.module.load("resnet18.so")
# name up to disicussion
assert isinstance(gmodfactory, GraphRuntimeFactory)
# automatically return the corresponding module wrapper by type key.
gmod = gmodfactory["resnet18"](ctx=tvm.cpu(0))
assert isinstance(gmodfactory, GraphRuntime)
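For reference, a hedged sketch of what the registration and load-time dispatch could look like. register_module_wrapper and wrap_module are hypothetical names, and a real implementation would also need to resolve the type key across RPC as noted above:

```python
# hypothetical global registry: type_key -> wrapper class
_WRAPPER_REGISTRY = {}

def register_module_wrapper(type_key, cls):
    """Associate a module type key with a Python wrapper class."""
    _WRAPPER_REGISTRY[type_key] = cls

def wrap_module(mod):
    """Return the registered wrapper for mod, or the raw module."""
    cls = _WRAPPER_REGISTRY.get(mod.type_key)
    return cls(mod) if cls is not None else mod
```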