Introduction
The TVM Object system provides a convenient way to share objects between the backend (C++) and the frontends (Python/Java/Rust/etc.). For example, one can construct a variable in Python and pass it to functions written in C++, and vice versa.
However, adding a single object node to the TVM stack requires manually adding code in several places in both Python and C++. For example, here is how tvm::tir::IntImm is implemented and registered:
- Definition for Node and its Reference: https://github.com/apache/incubator-tvm/blob/master/include/tvm/ir/expr.h#L228-L270
- Implement functionality: https://github.com/apache/incubator-tvm/blob/master/src/ir/expr.cc#L58-L68
- Node registry in C++: https://github.com/apache/incubator-tvm/blob/master/src/ir/expr.cc#L70-L84
- Node registry in Frontend (Python): https://github.com/apache/incubator-tvm/blob/master/python/tvm/tir/expr.py#L275-L290
This RFC proposes generating the C++ implementation directly from a Python class definition and registry. Moreover, since users can still write C++ code manually to bring in more complex features, the object transpiler will provide basic validation for such manually written C++ code.
Here is an example of how an object can be described in Python and what the generated C++ code looks like:
@declare
class BaseExprNode(Object):
    """
    Base type of all the expressions.

    See Also
    --------
    BaseExpr
    """
    type_key = "BaseExpr"
    default_visit_attrs = False
    default_sequal_reduce = False
    default_shash_reduce = False

@declare
class IntImmNode(PrimExprNode):
    """
    Constant integer literals in the program.

    See Also
    --------
    IntImm

    Attributes
    ----------
    value
        The internal value.
    """
    type_key = "IntImm"
    value: ty.int64_t
/*!
 * \brief Base type of all the expressions.
 * \sa Expr
 */
class BaseExprNode : public Object {
 public:
  TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};

/*!
 * \brief Managed reference to BaseExprNode.
 * \sa BaseExprNode
 */
class BaseExpr : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode);
};

/*!
 * \brief Constant integer literals in the program.
 */
class IntImmNode : public PrimExprNode {
 public:
  /*! \brief The internal value. */
  int64_t value;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("dtype", &dtype);
    v->Visit("value", &value);
  }

  bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
    return equal(dtype, other->dtype) && equal(value, other->value);
  }

  void SHashReduce(SHashReducer hash_reducer) const {
    hash_reducer(dtype);
    hash_reducer(value);
  }

  static constexpr const char* _type_key = "IntImm";
  TVM_DECLARE_BASE_OBJECT_INFO(IntImmNode, PrimExprNode);
};

/*!
 * \brief Managed reference class to IntImmNode.
 *
 * \sa IntImmNode
 */
class IntImm : public PrimExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
};
We call it the TVM Object Schema DSL, or tschema. In summary, tschema brings several benefits to the TVM architecture:
- Reduce boilerplate code;
- Validate definitions so that items such as TVM_REGISTER(...) are not missed;
- Enable deployment in all kinds of environments, even without C++;
- Fields like type_child_slots can be generated automatically for optimization;
- Allow users to define Objects in Python, then build and export them to a .o/.so file;
- Provide more type information at runtime, enabling some optimizations in TIR compilation.
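To make the validation benefit concrete, here is a minimal sketch of such a check. It assumes the registration of interest is a TVM_REGISTER_NODE_TYPE(...) call in a .cc file; the helper name missing_registrations is hypothetical and not part of this RFC.

```python
import re
from typing import Iterable, List

# Assumption: nodes are registered with TVM_REGISTER_NODE_TYPE(NodeName).
REGISTER_CALL = re.compile(r"\bTVM_REGISTER_NODE_TYPE\(\s*(\w+)\s*\)")


def missing_registrations(declared: Iterable[str], cc_source: str) -> List[str]:
    """Return the declared node names that have no registration in cc_source."""
    registered = set(REGISTER_CALL.findall(cc_source))
    return [name for name in declared if name not in registered]


cc_text = "TVM_REGISTER_NODE_TYPE(IntImmNode);\n"
unregistered = missing_registrations(["IntImmNode", "GlobalVarNode"], cc_text)
print(unregistered)
```

A real check would more likely run over every generated and hand-written source file and fail the build when the returned list is non-empty.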
High-level Object Compilation Pipeline
- Define the TVM Object in Python. The object definition file lives in a separate directory (which will not be part of PYTHONPATH), outside python/tvm/.
- Run the Python class parser to generate the related .h and .cc files. This step can be triggered manually or via cmake. The generated files will be checked into the code base so that code completion tools can locate them.
- Compile TVM using cmake as usual.
Note that the second step happens during (or before) the compilation of TVM itself. We provide a standalone tool to parse the Python code.
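As a rough illustration of the second step, the sketch below emits a node class from a list of fields. The function emit_node_class and its parameters are hypothetical and not part of this RFC; a real generator would also emit the reference class, the doc comments, and the .cc registration.

```python
from typing import List, Tuple

Field = Tuple[str, str]  # (member name, C++ type)


def emit_node_class(node_name: str, parent: str, type_key: str,
                    inherited: List[Field], fields: List[Field]) -> str:
    """Emit a C++ node class with default visit/equal/hash methods (hypothetical helper)."""
    # Inherited members are visited first, followed by the members declared here.
    names = [name for name, _ in inherited + fields]
    equal_expr = " && ".join(f"equal({n}, other->{n})" for n in names) or "true"

    lines = [f"class {node_name} : public {parent} {{", " public:"]
    lines += [f"  {ctype} {name};" for name, ctype in fields]
    lines.append("  void VisitAttrs(AttrVisitor* v) {")
    lines += [f'    v->Visit("{n}", &{n});' for n in names]
    lines.append("  }")
    lines.append(f"  bool SEqualReduce(const {node_name}* other, SEqualReducer equal) const {{")
    lines.append(f"    return {equal_expr};")
    lines.append("  }")
    lines.append("  void SHashReduce(SHashReducer hash_reducer) const {")
    lines += [f"    hash_reducer({n});" for n in names]
    lines.append("  }")
    lines.append(f'  static constexpr const char* _type_key = "{type_key}";')
    lines.append(f"  TVM_DECLARE_BASE_OBJECT_INFO({node_name}, {parent});")
    lines.append("};")
    return "\n".join(lines)


header = emit_node_class(
    "IntImmNode", "PrimExprNode", "IntImm",
    inherited=[("dtype", "DataType")],
    fields=[("value", "int64_t")],
)
print(header)
```

Writing the returned string to a .h file, either by hand or from a cmake custom command, would correspond to the step described above.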
TSchema DSL
Take IntImm as an example:
@declare
class IntImmNode(PrimExprNode):
    """
    Constant integer literals in the program.

    See Also
    --------
    IntImm

    Attributes
    ----------
    value
        The internal value.
    """
    type_key = "IntImm"
    value: ty.int64_t
Several things need to be parsed:
- Object name. In the above example it is IntImmNode, therefore class IntImmNode (extends Object) will be generated.
- Type key. In the above example it is IntImm, therefore class IntImm (extends ObjectRef) will be generated.
- Parent class. In the above example it is PrimExprNode.
- Member variables. In the above example they are:
  - value and its type annotation int64_t
- The constructor arguments in C++ will be generated in the same order as the arguments in the Python class definition.
- We will also generate default VisitAttrs, SEqualReduce, and SHashReduce methods unless the user sets the corresponding flag (default_visit_attrs, default_sequal_reduce, default_shash_reduce) to False.
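The items above could be collected by the decorator itself. The following is a minimal sketch under stated assumptions: Schema, SCHEMAS, the stand-in Object and ty classes, and this implementation of declare are all hypothetical, and PrimExprNode is declared directly under Object only to keep the example short.

```python
import inspect
from dataclasses import dataclass
from typing import Dict, List, Tuple


class Object:
    """Stand-in for the root of the object hierarchy."""


class ty:
    """Stand-in type namespace: each attribute names a C++ type."""
    int64_t = "int64_t"
    DataType = "DataType"


@dataclass
class Schema:
    node_name: str                 # object name, e.g. "IntImmNode"
    ref_name: str                  # reference class name, taken from type_key
    parent: str                    # parent node class name
    fields: List[Tuple[str, str]]  # (name, C++ type) in declaration order
    flags: Dict[str, bool]         # default_* switches; True when unset


SCHEMAS: Dict[str, Schema] = {}
FLAG_NAMES = ("default_visit_attrs", "default_sequal_reduce", "default_shash_reduce")


def _own_annotations(cls) -> dict:
    """Annotations declared on cls itself, excluding inherited ones."""
    if hasattr(inspect, "get_annotations"):  # Python 3.10+
        return inspect.get_annotations(cls)
    return dict(cls.__dict__.get("__annotations__", {}))


def declare(cls):
    """Record the pieces listed above and return the class unchanged."""
    own = vars(cls)
    schema = Schema(
        node_name=cls.__name__,
        ref_name=own["type_key"],
        parent=cls.__bases__[0].__name__,
        fields=list(_own_annotations(cls).items()),
        flags={name: own.get(name, True) for name in FLAG_NAMES},
    )
    SCHEMAS[schema.node_name] = schema
    return cls


@declare
class PrimExprNode(Object):
    type_key = "PrimExpr"
    dtype: ty.DataType


@declare
class IntImmNode(PrimExprNode):
    type_key = "IntImm"
    value: ty.int64_t


print(SCHEMAS["IntImmNode"])
```

Reading the flags from the class's own namespace rather than through inheritance matches the earlier example, where BaseExprNode disables the defaults while IntImmNode still receives generated methods.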
Inplace C++ Source File Modification
As we mentioned before, there are cases where users need to implement complex functions manually. To leverage the convenience of Python declaration and automatic code generation in such cases, we provide an option to modify the C++ source file in-place, and give users the control to specify which part of the file can be modified.
We provide a comment parser for .h and .cc files, so that users can wrap the auto-generated section in comments, e.g.,
// tschema: ObjectName
The lines between tschema: ObjectName and tschema: end
will be manipulated by tschema
// tschema: custom-begin
Users can also mark sections which should be left unchanged by tschema
This section will be inserted at the end of the class definition,
right before the close brace
// tschema: custom-end
// tschema: end
Here is an example:
Before generation
// tschema: GlobalVarNode
// tschema: custom-begin
  bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
    return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other);
  }
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(name_hint);
    hash_reduce.FreeVarHashImpl(this);
  }
// tschema: custom-end
// tschema: end
TSchema Definition
@declare
class GlobalVarNode(RelayExprNode):
    """
    Global variable that lives in the top-level module.
    A GlobalVar only refers to function definitions.
    This is used to enable recursive calls between functions.

    See Also
    --------
    GlobalVar

    Attributes
    ----------
    name_hint
        The name of the variable, this only acts as a hint.
    """
    type_key = "GlobalVar"
    default_sequal_reduce = False
    default_shash_reduce = False
    name_hint: ty.String
Generated Code
// tschema: GlobalVarNode
class GlobalVarNode : public RelayExprNode {
 public:
  String name_hint;
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("span", &span);
    v->Visit("checked_type_", &checked_type_);
    v->Visit("name_hint", &name_hint);
  }
  static constexpr const char* _type_key = "GlobalVar";
  TVM_DECLARE_BASE_OBJECT_INFO(GlobalVarNode, RelayExprNode);
// tschema: custom-begin
  bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
    return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other);
  }
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(name_hint);
    hash_reduce.FreeVarHashImpl(this);
  }
// tschema: custom-end
};
// tschema: end
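The in-place update illustrated above could be implemented roughly as follows. This is only a sketch: the function splice, its parameters, and the regular expressions are hypothetical, and it assumes each object has exactly one marked region with at most one custom section.

```python
import re


def splice(source: str, name: str, generated_head: str, generated_tail: str = "};") -> str:
    """Rewrite the region marked for `name`, keeping its custom section verbatim."""
    region = re.compile(
        rf"(?P<open>^[ \t]*// tschema: {re.escape(name)}[ \t]*\n)"
        r"(?P<body>.*?)"
        r"(?P<close>^[ \t]*// tschema: end[ \t]*$)",
        re.DOTALL | re.MULTILINE,
    )
    custom = re.compile(
        r"^[ \t]*// tschema: custom-begin.*?// tschema: custom-end[ \t]*\n",
        re.DOTALL | re.MULTILINE,
    )

    def rebuild(match):
        kept = custom.search(match.group("body"))
        kept_text = kept.group(0) if kept else ""
        # Generated members first, then the user's section, then the closing brace.
        return (match.group("open") + generated_head + kept_text
                + generated_tail + "\n" + match.group("close"))

    new_source, count = region.subn(rebuild, source)
    if count != 1:
        raise ValueError(f"expected exactly one tschema region for {name}, found {count}")
    return new_source


before = """\
// tschema: GlobalVarNode
// tschema: custom-begin
  bool Custom() const { return true; }
// tschema: custom-end
// tschema: end
"""
head = "class GlobalVarNode : public RelayExprNode {\n public:\n  String name_hint;\n"
after = splice(before, "GlobalVarNode", head)
print(after)
```

Because everything outside the custom markers is discarded and regenerated, running the tool a second time on its own output leaves the file unchanged, which keeps checked-in generated files stable.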
@tqchen @yzhliu @jwfromm @jroesch @junrushao , also thanks Yizhi for the initial idea and RFC writing.