I am looking to create a simple custom C codegen to interact with an external library, and I need some guidance.
I started with the RelayToTIR + TIRToRuntime hooks in order to take advantage of unified static memory planning (as described in the “Additional Target Hooks” RFC), and I’m having trouble creating the custom lowering inside RelayToTIR.
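For context, the Python side of my RelayToTIR hook is roughly the skeleton below (a sketch, not my exact code; the hook itself is attached to the "toy_cg_thflow" target kind in C++, and LowerFusedNodesSimple is the mutator shown later in this post):

import tvm
from tvm import relay

@tvm.ir.transform.module_pass(opt_level=0)
class RelayToTIRToy:
    def transform_module(self, mod, ctx):
        # Collect first so the module is not mutated while iterating it
        to_lower = [
            (gvar, func)
            for gvar, func in mod.functions.items()
            if isinstance(func, relay.Function)
            and func.attrs is not None
            and "Compiler" in func.attrs
            and func.attrs["Compiler"] == "toy_cg_thflow"
        ]
        for gvar, func in to_lower:
            # Replace each outlined Relay function with its lowered PrimFunc
            mod.update_func(gvar, LowerFusedNodesSimple().visit(func))
        return mod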
I have created a partition function that includes the MergeCompilerRegions and OutlineCompilerFunctionsWithExistingGlobalSymbols passes, with the target “toy_cg_thflow” and the composite functions “dense_bias” and “dense_bias_relu”; it looks roughly like the sketch below.
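Here is the sketch of the partition function (the pattern definitions are reconstructed from the composite names that appear in the IR below):

from tvm import relay, transform
from tvm.relay.dataflow_pattern import is_op, wildcard

def make_pattern_table():
    dense = is_op("nn.dense")(wildcard(), wildcard())
    bias = is_op("nn.bias_add")(dense, wildcard())
    relu = is_op("nn.relu")(bias)
    # Longer pattern first so dense + bias_add + relu is matched greedily
    return [
        ("toy_cg_thflow.dense_bias_relu", relu),
        ("toy_cg_thflow.dense_bias", bias),
    ]

def partition_for_toy_cg_thflow(mod):
    seq = transform.Sequential([
        relay.transform.MergeComposite(make_pattern_table()),
        relay.transform.AnnotateTarget("toy_cg_thflow"),
        relay.transform.MergeCompilerRegions(),
        relay.transform.PartitionGraph(),
        relay.transform.OutlineCompilerFunctionsWithExistingGlobalSymbols("toy_cg_thflow"),
    ])
    return seq(mod)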
Here is an example of an outlined global function at the input to my RelayToTIR pass:
def @tvmgen_default_toy_cg_thflow_main_0(%toy_cg_thflow_0_i0: Tensor[(64, 800), float32] /* ty=Tensor[(64, 800), float32] */, %toy_cg_thflow_0_i1: Tensor[(500, 800), float32] /* ty=Tensor[(500, 800), float32] */, %toy_cg_thflow_0_i2: Tensor[(500), float32] /* ty=Tensor[(500), float32] */, %toy_cg_thflow_0_i3: Tensor[(10, 500), float32] /* ty=Tensor[(10, 500), float32] */, %toy_cg_thflow_0_i4: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, Compiler="toy_cg_thflow", Primitive=1, Inline=1, global_symbol="tvmgen_default_toy_cg_thflow_main_0") -> Tensor[(64, 10), float32] {
  %13 = fn (%FunctionVar_0_01: Tensor[(64, 800), float32] /* ty=Tensor[(64, 800), float32] */, %FunctionVar_0_11: Tensor[(500, 800), float32] /* ty=Tensor[(500, 800), float32] */, %FunctionVar_0_21: Tensor[(500), float32] /* ty=Tensor[(500), float32] */, PartitionedFromPattern="nn.dense_nn.bias_add_nn.relu_", Composite="toy_cg_thflow.dense_bias_relu") -> Tensor[(64, 500), float32] {
    %11 = nn.dense(%FunctionVar_0_01, %FunctionVar_0_11, units=None) /* ty=Tensor[(64, 500), float32] span=aten::linear_0:0:0 */;
    %12 = nn.bias_add(%11, %FunctionVar_0_21, axis=-1) /* ty=Tensor[(64, 500), float32] span=aten::linear_0:0:0 */;
    nn.relu(%12) /* ty=Tensor[(64, 500), float32] span=aten::relu_2:0:0 */
  } /* ty=fn (Tensor[(64, 800), float32], Tensor[(500, 800), float32], Tensor[(500), float32]) -> Tensor[(64, 500), float32] */;
  %14 = %13(%toy_cg_thflow_0_i0, %toy_cg_thflow_0_i1, %toy_cg_thflow_0_i2) /* ty=Tensor[(64, 500), float32] */;
  %15 = fn (%FunctionVar_0_0: Tensor[(64, 500), float32] /* ty=Tensor[(64, 500), float32] */, %FunctionVar_0_1: Tensor[(10, 500), float32] /* ty=Tensor[(10, 500), float32] */, %FunctionVar_0_2: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, PartitionedFromPattern="nn.dense_nn.bias_add_", Composite="toy_cg_thflow.dense_bias") -> Tensor[(64, 10), float32] {
    %10 = nn.dense(%FunctionVar_0_0, %FunctionVar_0_1, units=None) /* ty=Tensor[(64, 10), float32] span=aten::linear_1:0:0 */;
    nn.bias_add(%10, %FunctionVar_0_2, axis=-1) /* ty=Tensor[(64, 10), float32] span=aten::linear_1:0:0 */
  } /* ty=fn (Tensor[(64, 500), float32], Tensor[(10, 500), float32], Tensor[(10), float32]) -> Tensor[(64, 10), float32] */;
  %15(%14, %toy_cg_thflow_0_i3, %toy_cg_thflow_0_i4) /* ty=Tensor[(64, 10), float32] */
}
I am now attempting to create a custom lowering that simply calls an external function for each composite node, passing in the inputs, weights, biases, and an output buffer.
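In other words, for each composite call I want to end up with a single extern call along these lines (the symbol name comes from the Composite attribute; passing the output buffer as the last argument is just the convention my library expects):

from tvm import tir

x, w, b, out = (tir.Var(n, "handle") for n in ("x", "w", "b", "out"))
call = tir.call_extern("int32", "dense_bias_relu", x, w, b, out)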
My naive attempt is an ExprMutator pass like this in Python:
import tvm
from tvm import tir
from tvm.relay import Call, Function, Var
from tvm.relay.expr_functor import ExprMutator

class LowerFusedNodesSimple(ExprMutator):
    # Convert a composite Call node to a tir.call_extern
    def visit_call(self, call: Call):
        fn = call.op
        if isinstance(fn, Function) and fn.attrs and "Composite" in fn.attrs:
            if fn.attrs["Composite"].split(".")[0] == "toy_cg_thflow":
                # Lower each argument; a nested composite call becomes a
                # nested call_extern expression (the problem shown below)
                call_ext_args = [self.visit(arg) for arg in call.args]
                return tir.call_extern(
                    fn.checked_type.ret_type.dtype,
                    fn.attrs["Composite"].split(".")[1],
                    *call_ext_args,
                )
        return super().visit_call(call)

    # Convert the outlined Function node to a PrimFunc
    def visit_function(self, fn: Function):
        # Recursively visit params (populating memo_map), then the body
        new_params = [self.visit(x) for x in fn.params]
        new_body = self.visit(fn.body)
        if fn.attrs and "global_symbol" in fn.attrs and "toy_cg_thflow" in fn.attrs["global_symbol"]:
            # One buffer per parameter, keyed by its lowered tir.Var
            buffer_dict = {
                self.memo_map[arg]: tir.decl_buffer(
                    shape=arg.type_annotation.shape,
                    dtype=arg.type_annotation.dtype,
                    name=arg.name_hint,
                )
                for arg in fn.params
            }
            attrs = {
                "global_symbol": fn.attrs["global_symbol"],
                "target": tvm.target.Target("toy_cg_thflow"),
                "tir.noalias": True,
            }
            dict_attrs = tvm.ir.make_node("DictAttrs", **attrs)
            return tir.PrimFunc(
                params=list(buffer_dict.keys()),
                body=tir.Evaluate(new_body),
                ret_type=fn.checked_type.ret_type,
                buffer_map=buffer_dict,
                attrs=dict_attrs,
            )
        return fn

    # Convert a relay Var to a tir Var
    def visit_var(self, var: Var):
        return tir.Var(var.name_hint, var.type_annotation.dtype)

    # Convert a relay Constant to a TIR immediate
    def visit_constant(self, const):
        return tir.FloatImm(const.checked_type.dtype, float(const.data.numpy()))
But this results in nested function calls with no intermediate buffers:
int32_t tvmgen_default_toy_cg_thflow_main_0(float* toy_cg_thflow_0_i0, float* toy_cg_thflow_0_i1, float* toy_cg_thflow_0_i2, float* toy_cg_thflow_0_i3, float* toy_cg_thflow_0_i4, uint8_t* global_const_workspace_16_var, uint8_t* global_workspace_17_var) {
dense_bias(dense_bias_relu(toy_cg_thflow_0_i0, toy_cg_thflow_0_i1, toy_cg_thflow_0_i2), toy_cg_thflow_0_i3, toy_cg_thflow_0_i4);
return 0;
}
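For comparison, here is a standalone sketch (untested, and assuming the extern functions take raw pointers with the output last) of the TIR shape I would like to end up with: one call_extern per composite, plus an explicit intermediate allocation that USMP could then plan:

import tvm
from tvm import tir

def desired_primfunc():
    # Parameter buffers mirroring tvmgen_default_toy_cg_thflow_main_0,
    # with an explicit output buffer as the last parameter
    shapes = [(64, 800), (500, 800), (500,), (10, 500), (10,), (64, 10)]
    handles = [tir.Var("p%d" % i, "handle") for i in range(len(shapes))]
    bufs = [tir.decl_buffer(s, "float32", "b%d" % i) for i, s in enumerate(shapes)]

    # Intermediate holding the (64, 500) result of dense_bias_relu
    inter = tir.decl_buffer((64, 500), "float32", "intermediate")

    call1 = tir.Evaluate(tir.call_extern(
        "int32", "dense_bias_relu",
        bufs[0].data, bufs[1].data, bufs[2].data, inter.data))
    call2 = tir.Evaluate(tir.call_extern(
        "int32", "dense_bias",
        inter.data, bufs[3].data, bufs[4].data, bufs[5].data))

    # Allocate the intermediate around the two sequenced extern calls
    body = tir.Allocate(inter.data, "float32", list(inter.shape),
                        tir.IntImm("bool", 1), tir.SeqStmt([call1, call2]))
    return tir.PrimFunc(handles, body, buffer_map=dict(zip(handles, bufs)))

In the generated C that would come out as roughly dense_bias_relu(i0, i1, i2, intermediate); followed by dense_bias(intermediate, i3, i4, out);.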
Does anyone know of an easy way to lower composite nodes to custom external function calls with intermediate buffers, similar to how the default codegen works? It seems like this should be a simple use case of BYOC.