This RFC is a case study on unifying the lower API, which had implementations in both Python and C++. I’ll call this duplication of APIs across Python and C++ “bifurcation”, or “bifurcation of the API”. I’ll conclude by analyzing how the bifurcation happened in the first place, and then present options for changing open source policy to avoid future bifurcation. You can skip ahead to “Causes of the bifurcation” if you don’t want to read the details of this case.
Unifying the Lower API
The PR is here if you would like to take a look.
In the case of lower, there were two lower functions, one in C++ and one in Python, with different signatures. The Python version was introduced first and is much broader than the C++ API. Notably, it accepts Union types for inputs and args, which is difficult to replicate in C++.
Here are the two function signatures from before the refactor, for reference:
def lower(
    inputs: Union[schedule.Schedule, PrimFunc, IRModule],
    args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
    name: str = "main",
    binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
    simple_mode: bool = False,
) -> IRModule
IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds);
To preserve the behavior of the Python API, I split the C++ backend into three functions, LowerModule, LowerPrimFunc, and LowerSchedule, and registered them through the FFI. Then, the Python lower function dispatches to the matching C++ backend based on the type passed in.
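The dispatch itself is simple type-based branching. Below is a minimal sketch of the idea, assuming the three backends are exposed through the FFI; the _ffi_api entry-point names are illustrative placeholders, not necessarily the names used in the PR.

# Minimal sketch of the type dispatch in the Python lower() wrapper.
# The _ffi_api names are illustrative placeholders for the FFI-registered
# C++ backends; argument conversion and validation are omitted.
from tvm import te, tir
from tvm.ir import IRModule
from tvm.driver import _ffi_api

def lower(inputs, args=None, name="main", binds=None, simple_mode=False):
    if isinstance(inputs, IRModule):
        return _ffi_api.LowerModule(inputs, simple_mode)
    if isinstance(inputs, tir.PrimFunc):
        return _ffi_api.LowerPrimFunc(inputs, name, simple_mode)
    if isinstance(inputs, te.Schedule):
        return _ffi_api.LowerSchedule(inputs, args, name, binds, simple_mode)
    raise ValueError(f"Expected Schedule, PrimFunc, or IRModule, got {type(inputs)}")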
This approach is OK for a small number of types, but as you can see, if we had more types in the Union or multiple arguments with Union types, we’d have to make even more C++ functions.
The other really tricky part was that args is an Array[Union[Buffer, Var, Tensor]]. Unfortunately, the only common supertype of Buffer, Var, and Tensor is ObjectRef, so to replicate the Python behavior in C++, we needed to let args be an Array<ObjectRef>. However, we also needed to preserve the original signature for backwards compatibility.
Thus, we end up with four function signatures in C++ to duplicate the behavior of one Python function, which is not ideal.
IRModule LowerModule(IRModule mod, bool simple_mode = false);
IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
bool simple_mode = false);
IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool simple_mode = false);
IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool simple_mode = false);
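To make the need for the Array<ObjectRef> overload concrete, here is a hypothetical user-side call in which args mixes a Tensor, a Buffer, and a Var. Since their only shared base class is ObjectRef, no more specific C++ array type can represent such a list.

# Hypothetical user code with a heterogeneous args list (Tensor, Buffer, Var).
import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
A_buf = tvm.tir.decl_buffer(A.shape, A.dtype, name="A_buf")

# B is a Tensor, A_buf is a Buffer, and n is a Var; on the C++ side this
# args list can only be represented as Array<ObjectRef>.
mod = tvm.lower(s, [B, A_buf, n], binds={A: A_buf})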
Writing the second version of LowerSchedule, which takes an Array<ObjectRef> for args, also caused some problems, because the helper functions it calls needed their signatures changed as well.
For this case study, we’ll just look at get_binds. In C++, there is a parallel helper function, GetBinds, but again, its signature doesn’t match that of the Python get_binds. The Python version allows args to be Array[Union[Buffer, Var, Tensor]], while the C++ version requires args to be an Array<Tensor>.
def get_binds(args, compact=False, binds=None):
The type of args in the Python version is Array[Union[Buffer, Var, Tensor]].
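The reason this is awkward to mirror in C++ is that the Python helper dispatches on the element type at runtime. The following is a simplified, illustrative sketch of that per-element handling, not the exact upstream implementation (which also handles the compact flag and buffer_type details).

# Simplified sketch of the per-element type dispatch inside the Python get_binds.
import tvm
from tvm import te

def get_binds(args, compact=False, binds=None):
    binds = {} if binds is None else dict(binds)
    arg_list = []
    for x in args:
        if isinstance(x, te.Tensor):
            # Tensors get a Buffer declared for them (or reuse one from binds).
            if x not in binds:
                binds[x] = tvm.tir.decl_buffer(x.shape, x.dtype, x.name)
            arg_list.append(binds[x])
        elif isinstance(x, (tvm.tir.Buffer, tvm.tir.Var)):
            # Buffers and Vars are passed through unchanged.
            arg_list.append(x)
        else:
            raise ValueError("args must contain Tensor, Buffer, or Var")
    return binds, arg_list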
void GetBinds(const Array<te::Tensor>& args, bool compact,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list);
To replicate the behavior of the Python get_binds, I had to write a second version of GetBinds in C++ that takes args as an Array<ObjectRef>. Additionally, get_binds is used in a few places in the codebase, so I had to register and expose GetBinds through the FFI to ensure backwards compatibility.
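For reference, once a C++ function is registered as a global through the FFI, the Python side can delegate to it while keeping the original signature. The global name "driver.get_binds" below is hypothetical, and this sketch assumes the exposed version returns the binds map and argument list rather than filling output parameters.

# Sketch of delegating the Python helper to an FFI-registered C++ function.
# "driver.get_binds" is a hypothetical registration name.
import tvm

_get_binds_cpp = tvm.get_global_func("driver.get_binds")

def get_binds(args, compact=False, binds=None):
    # Keep the original Python signature so existing callers keep working.
    return _get_binds_cpp(args, compact, binds if binds is not None else {})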
The final problem I encountered was related to dead code. The C++ version of lower is only called in one place, and it calls SchedulePostProcRewriteForTensorCore, which causes failures in some of the calls that originate from Python but not in those that originate from C++. Fortunately, this function is dead code, so I just removed it. However, it could have caused further problems if it were actually required in some cases but not in others.
Causes of the bifurcation
- Python lower is introduced, and it relies on Python duck typing.
- Because the code is in Python, all of it is effectively a public API, so people start using lower, get_binds, and form_irmodule in ways that they shouldn’t.
- Someone needs to introduce the C++ version, but it’s hard to duplicate the Python API in C++ because of duck typing. So the new C++ version doesn’t match all Python cases, and the Python is not removed, to preserve backwards compatibility.
- Because the Python code is the primary code used, it is updated and maintained more frequently than the C++ version, which resulted in the C++ version having a call to SchedulePostProcRewriteForTensorCore that should have been removed from the codebase. Additionally, that was the only call to it in the entire codebase, so we were able to remove the body of SchedulePostProcRewriteForTensorCore, which was over 1000 LOC.
Recommendations on OSS Policy Change
First, the current TVM “policy” is that code can be implemented in Python for prototyping and ease of implementation, and eventually that code will be replaced with C++ and the Python will be deleted. I am not sure if this is written down anywhere; however, I have heard it from community members when asking why there are both Python and C++ versions of functions.
In my view, the main problem with this policy is that the Python code is usually never removed from the codebase, or is removed only after it has caused problems (as in the case of lower).
I see two potential solutions to prevent this from happening, which have different overheads in terms of enforcement:
Option 1: Turn on static type checking in Python to make it easier to convert Python to C++, and require that Python code be removed from the codebase when the corresponding C++ is added. Currently, we assume that the Python will be removed later, which I do not think is a good assumption to make.
Option 2: Require that all new compiler APIs (as well as most logic in general) be written in C++, and only allow Python to serve as a frontend and soft wrapper of the C++ APIs.
It would be great if people could weigh in on which solution they prefer, or provide alternative solutions. I personally prefer Option 2. Although it is a bit strict, I think it will be easier to enforce than Option 1, since enforcing Option 1 requires reviewers to check whether there is a Python API that the new C++ duplicates. Additionally, turning on static type checking in Python will probably require lots of changes to existing code and take a long time.