Problem:
TensorFlow is one of the most popular machine learning libraries, and most developers are used to training and serving models with TensorFlow/TensorFlow Serving. TVM is a flexible compiler that runs computation efficiently on different devices. Although TensorFlow ships some efficient GPU operators, developers can benefit from TVM to get more than 10x speedups as well as FPGA support. However, TensorFlow and TVM have two different code stacks and runtime APIs.
There are two ways to integrate TVM with TensorFlow. The first one is tensorflow-to-tvm, which is already supported by the Relay importer. Most TensorFlow operators can be "translated" to TVM operators, which is useful if we want to run the TVM stack with a model structure coming from another framework; a rough sketch is shown below.
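For reference only, importing a frozen TensorFlow graph into Relay looks roughly like the following. The exact signature and return values of relay.frontend.from_tensorflow vary between TVM versions, and graph_def/shape_dict are placeholders that the user must provide.
import tvm
from tvm import relay

# graph_def: a frozen tf.GraphDef; shape_dict: map of input name -> shape (placeholders)
mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
# Compile the imported model with the TVM stack
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target="llvm", params=params)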
The second one is tvm-to-tensorflow. This requires embedding TVM operators in the TensorFlow graph so that we can use a TensorFlow session to run both built-in operators and TVM-optimized operators. This is really helpful if we want to use TVM to optimize part of the computation graph while developers keep using the TensorFlow Python API to describe the model and TensorFlow Serving for inference. Embedding TVM in TensorFlow requires minimal effort to apply TVM optimization to existing models and to extend TensorFlow with functionality such as FPGA support.
This RFC describes our design for supporting tvm-to-tensorflow with the TensorFlow custom op API, along with the details of the implementation.
Considerations:
We want to support the complete TVM stack. TVM provides an efficient C++ API, an easy-to-use Python API for defining kernels and schedules, and AutoTVM for searching optimal parameters. We want developers to use these existing tools to define TVM operators and embed the exported library files, instead of re-implementing the same logic in TensorFlow.
We want no development effort for end users. TensorFlow provides a C++ API to define custom ops and a Python API to load them from dynamic libraries. We don't want users to write C++ and Python code to wrap a TVM op as a TensorFlow custom op by themselves. We can implement a general C++ TVM runtime operator and Python class so that users can use TVM ops in TensorFlow without implementing any TensorFlow custom op.
We want minimal code changes for usage. Some code change is inevitable because we need to specify which op to replace. We can use the TensorFlow graph editor API to replace the original TensorFlow op with the TVM op (see the sketch below). Since TVM already supports some TensorFlow ops with the same functionality and may offer better performance, we can also design a tool like TF-TRT that automatically rewrites a TensorFlow SavedModel into an optimized one with TVM ops. This can be done once we can embed TVM ops in the TensorFlow graph and replace TensorFlow ops with them.
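As a rough illustration (not part of this proposal's code), replacing an op could use tf.contrib.graph_editor. The tvm_runtime wrapper below is the API proposed later in this RFC, and the exact rerouting call and its argument order should be treated as an assumption to verify against the graph editor documentation.
import tensorflow as tf
from tensorflow.contrib import graph_editor as ge
from tvm.contrib import tf_runtime as tvm_runtime  # proposed wrapper, name assumed

a = tf.constant([10.1, 20.0, 11.2, -30.3])
b = tf.add(a, 1.0)    # original op we want to replace
c = tf.identity(b)    # downstream consumer of the original op
# Build the TVM-backed replacement op on the same input
b_tvm = tvm_runtime(a, lib_path="tvm_addone_dll.so", function_name="addone")
# Reroute consumers of b so they read from b_tvm instead (semantics assumed)
ge.reroute_ts([b_tvm], [b])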
Proposal:
The code cannot be merged into the TVM codebase yet, but the API should look similar to what follows. Users use the TVM stack to build ops and export them as dynamic library files. Here is the example code to export the TVM dynamic libraries for CPU and GPU.
# CPU
import os
import tvm

# base_path: directory to write the exported library to
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
s = tvm.create_schedule(B.op)
fadd_dylib = tvm.build(s, [A, B], "llvm", name="addone")
dylib_path = os.path.join(base_path, "test_addone_dll.so")
fadd_dylib.export_library(dylib_path)
# GPU: create a fresh schedule and bind the loop to CUDA threads
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
fadd_dylib = tvm.build(s, [A, B], "cuda", name="addone")
dylib_path = os.path.join(base_path, "test_addone_cuda_dll.so")
fadd_dylib.export_library(dylib_path)
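To sanity-check an exported library before wrapping it in TensorFlow, it can be loaded back with the plain TVM runtime. This is a minimal sketch for the CPU library, assuming the TVM Python API of this era (tvm.module.load).
import numpy as np
import tvm

# Load the exported library and look up the packed function by name
m = tvm.module.load("test_addone_dll.so")
fadd = m["addone"]
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=1024).astype("float32"), ctx)
b = tvm.nd.array(np.zeros(1024, dtype="float32"), ctx)
fadd(a, b)
np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1)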
Then we can use the pre-built TensorFlow custom op for the TVM runtime through a Python wrapper. This op behaves like any other TensorFlow op and can be used in a TensorFlow graph and session. There are two options for wrapping this Python API.
Option one is to use the TensorFlow custom op Python API directly and initialize the op with the library file path and function name.
import tensorflow as tf
from tvm.contrib import tf_runtime

with tf.Session() as sess:
    a = tf.constant([10.1, 20.0, 11.2, -30.3])
    b = tf_runtime(a, lib_path="tvm_addone_dll.so", function_name="addone")
    print(sess.run(b))
Option two is to extend the GraphModule in the TVM Python API, or to wrap it with a new tf_runtime.Module class.
import tensorflow as tf
from tvm.contrib import graph_runtime

mod = graph_runtime.create(graph, lib, ctx)
addone = mod["addone"]

with tf.Session() as sess:
    a = tf.constant([10.1, 20.0, 11.2, -30.3])
    b = addone(a)
    print(sess.run(b))
However, with option two we still have to call the underlying TensorFlow API to load the custom op and return the tensor object. The GraphModule is essentially a Python bridge to C++ that runs TVM ops directly, rather than a standard TensorFlow op executed by a TensorFlow session. It is fine to wrap the TensorFlow custom op with any Python class, as long as it matches the usage of other TensorFlow ops.
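For illustration, a thin wrapper along these lines could hide the custom op loading. The class name, constructor arguments, and the generated op name (tvm_runtime, derived from REGISTER_OP("TvmRuntime")) are assumptions and may differ from the actual implementation in the repository.
import tensorflow as tf

class TvmRuntimeModule(object):
    """Hypothetical wrapper that exposes TVM functions as TensorFlow ops."""

    def __init__(self, custom_op_lib, tvm_lib_path):
        # Load the compiled generic TensorFlow custom op for the TVM runtime
        self._op_lib = tf.load_op_library(custom_op_lib)
        self._tvm_lib_path = tvm_lib_path

    def __getitem__(self, function_name):
        def _call(tensor):
            # "tvm_runtime" is the snake_case name TensorFlow generates
            # for the op registered as "TvmRuntime" (assumed here)
            return self._op_lib.tvm_runtime(
                tensor, lib_path=self._tvm_lib_path, function_name=function_name)
        return _call

# Usage: mod = TvmRuntimeModule("tvm_runtime_op.so", "tvm_addone_dll.so")
#        addone = mod["addone"]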
The TensorFlow custom op for the TVM runtime can be implemented by combining the TVM runtime C++ API and the TensorFlow custom op C++ API. We have an implementation and examples in https://github.com/tobegit3hub/tftvm which can be moved to tvm.contrib once the API is determined. Here is the code to register the TensorFlow custom op; it requires lib_path and function_name to load the TVM dynamic library. Moreover, the TVM runtime API needs to know the dtype and shape of the input tensors; this information can be passed either through TensorFlow op attrs or loaded from the TVM dynamic library.
REGISTER_OP("TvmRuntime")
    .Attr("lib_path: string")
    .Attr("function_name: string")
    .Input("tvm_input: float")
    .Output("tvm_output: float");
Then we can implement the CPU and GPU kernels with the TVM runtime C++ API. We load the dynamic library with the attribute parameters when the op is initialized. For each invocation of the op, we override the Compute method to run the TVM function and read/write the data as TensorFlow tensors.
void Compute(OpKernelContext* context) override {
  int device_type = TvmRuntimeOpTrait<DEVICE_TYPE>::device_type;
  int device_id = TvmRuntimeOpTrait<DEVICE_TYPE>::device_id(context);

  // Allocate TVM arrays (DLTensor) for input and output.
  // The shape is hard-coded in this example; in the real implementation it
  // should come from the op attrs or from the TVM dynamic library.
  DLTensor* x;
  DLTensor* y;
  int64_t shape[1] = {10};
  TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
                device_type, device_id, &x);
  TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
                device_type, device_id, &y);

  // Get the input tensor and point the TVM array at its data
  const Tensor& input_tensor = context->input(0);
  auto input = input_tensor.flat<float>();
  x->data = const_cast<float*>(input.data());
  const int input_size = input.size();

  // TVM run
  tvm_func(x, y);

  // Copy the result into the output tensor
  Tensor* output_tensor = NULL;
  OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                   &output_tensor));
  auto output_flat = output_tensor->flat<float>();
  memcpy(output_flat.data(), y->data, input_size * sizeof(float));
  // cudaMemcpy(output_flat.data(), y->data, input_size * sizeof(float), cudaMemcpyDeviceToDevice);
}
Notice that the TensorFlow custom op only needs to be built once and can then be used to load different TVM operators. Developers need a TensorFlow environment to build this custom op with g++ or bazel.
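For example, the compile and link flags needed to build the custom op against the installed TensorFlow can be queried from Python (available in recent TensorFlow 1.x releases) and passed to g++:
import tensorflow as tf

# Flags to pass to g++ when compiling and linking the custom op
print(" ".join(tf.sysconfig.get_compile_flags()))
print(" ".join(tf.sysconfig.get_link_flags()))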
Finally, we can use the TensorFlow graph editor API to replace original TensorFlow ops with the TVM-optimized custom op, and build automatic TensorFlow SavedModel converters. These are less important than the TensorFlow custom op itself, and we can discuss them once the fundamental functionality is ready.
Related discussion is in Add TensorFlow custom op and run tvm in TensorFlow.