I added an external function via lower_composit("bmlib.conv2d") to lower it to TIR, and added a pattern:
def reshape_conv2d_add_bn_relu_pattern1():
    """Create a reshape -> conv2d -> add -> batch_norm -> relu pattern.

    Every operand (reshape input, conv2d data/weight, batch_norm scale/shift
    stats) is matched with ``wildcard()``, so the partitioner lifts all of
    them — including bound constants — into parameters of the composite
    function rather than embedding them as constants.

    Returns:
        The dataflow pattern rooted at the final ``nn.relu``.
    """
    print("reshape_conv2d_add_bn_relu_pattern")
    reshape_pat = is_op("reshape")(wildcard())
    conv_pat = is_op("nn.conv2d")(wildcard(), wildcard())
    pattern = is_op("add")(reshape_pat, conv_pat)
    pattern = is_op("nn.batch_norm")(pattern, wildcard(), wildcard(), wildcard(), wildcard())
    # batch_norm returns a tuple; index 0 is the normalized output.
    pattern = is_tuple_get_item(pattern, 0)
    pattern = is_op("nn.relu")(pattern)
    return pattern
When I create a Relay module as follows and call the function bind_params_by_name:
def @main(%data1: Tensor[(128), float32], %data: Tensor[(1, 128, 12, 40), float32], %weight: Tensor[(128, 128, 1, 1), float32], %gamma: Tensor[(128), float32], %beta: Tensor[(128), float32], %moving_mean: Tensor[(128), float32], %moving_var: Tensor[(128), float32]) {
%0 = reshape(%data1, newshape=[1, 128, 1, 1]);
%1 = nn.conv2d(%data, %weight, padding=[0, 0, 0, 0]);
%2 = add(%0, %1);
%3 = nn.batch_norm(%2, %gamma, %beta, %moving_mean, %moving_var);
%4 = %3.0;
nn.relu(%4)
}
after bind params...
def @main(%data: Tensor[(1, 128, 12, 40), float32]) {
%0 = reshape(meta[relay.Constant][0], newshape=[1, 128, 1, 1]);
%1 = nn.conv2d(%data, meta[relay.Constant][1], padding=[0, 0, 0, 0]);
%2 = add(%0, %1);
%3 = nn.batch_norm(%2, meta[relay.Constant][2], meta[relay.Constant][3], meta[relay.Constant][4], meta[relay.Constant][5]);
%4 = %3.0;
nn.relu(%4)
}
Then I partition the graph using the pattern and print the module:
After partition graph...
def @main(%data: Tensor[(1, 128, 12, 40), float32] /* ty=Tensor[(1, 128, 12, 40), float32] */) -> Tensor[(1, 128, 12, 40), float32] {
@tvmgen_default_bmlib_main_0(%data) /* ty=Tensor[(1, 128, 12, 40), float32] */
}
def @tvmgen_default_bmlib_main_0(%bmlib_0_i1: Tensor[(1, 128, 12, 40), float32] /* ty=Tensor[(1, 128, 12, 40), float32] */, Inline=1, global_symbol="tvmgen_default_bmlib_main_0", Compiler="bmlib", Primitive=1) -> Tensor[(1, 128, 12, 40), float32] {
%5 = fn (%FunctionVar_0_0: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %FunctionVar_0_1: Tensor[(1, 128, 12, 40), float32] /* ty=Tensor[(1, 128, 12, 40), float32] */, %FunctionVar_0_2: Tensor[(128, 128, 1, 1), float32] /* ty=Tensor[(128, 128, 1, 1), float32] */, %FunctionVar_0_3: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %FunctionVar_0_4: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %FunctionVar_0_5: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %FunctionVar_0_6: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, PartitionedFromPattern="reshape_nn.conv2d_add_nn.batch_norm_TupleGetItem0_nn.relu_", Composite="bmlib.reshape_conv2d_add_bn_relu") -> Tensor[(1, 128, 12, 40), float32] {
%0 = reshape(%FunctionVar_0_0, newshape=[1, 128, 1, 1]) /* ty=Tensor[(1, 128, 1, 1), float32] */;
%1 = nn.conv2d(%FunctionVar_0_1, %FunctionVar_0_2, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 128, 12, 40), float32] */;
%2 = add(%0, %1) /* ty=Tensor[(1, 128, 12, 40), float32] */;
%3 = nn.batch_norm(%2, %FunctionVar_0_3, %FunctionVar_0_4, %FunctionVar_0_5, %FunctionVar_0_6) /* ty=(Tensor[(1, 128, 12, 40), float32], Tensor[(128), float32], Tensor[(128), float32]) */;
%4 = %3.0 /* ty=Tensor[(1, 128, 12, 40), float32] */;
nn.relu(%4) /* ty=Tensor[(1, 128, 12, 40), float32] */
} /* ty=fn (Tensor[(128), float32], Tensor[(1, 128, 12, 40), float32], Tensor[(128, 128, 1, 1), float32], Tensor[(128), float32], Tensor[(128), float32], Tensor[(128), float32], Tensor[(128), float32]) -> Tensor[(1, 128, 12, 40), float32] */;
%5(meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, %bmlib_0_i1, meta[relay.Constant][1] /* ty=Tensor[(128, 128, 1, 1), float32] */, meta[relay.Constant][2] /* ty=Tensor[(128), float32] */, meta[relay.Constant][3] /* ty=Tensor[(128), float32] */, meta[relay.Constant][4] /* ty=Tensor[(128), float32] */, meta[relay.Constant][5] /* ty=Tensor[(128), float32] */) /* ty=Tensor[(1, 128, 12, 40), float32] */
}
After I register relay.ext.bmlib, I call relay.build:
tvm._ffi.register_func("relay.ext.bmlib", relay_to_runtime(tvm.target.Target("llvm")))
I encountered an error:
tvmgen_default_bmlib_main_0: num_args should be 8
I checked and found the following in the relay_to_runtime function:
def _relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module:
    """Compile a partitioned Relay function to a runtime module using Tensor Expressions.

    Args:
        partition: The outer partitioned function (Compiler="...") whose body
            is a call to an inner composite function carrying the
            ``Composite`` attribute registered in ``_LOWER_MAP``.

    Returns:
        A built runtime module named after the partition's global symbol.

    Raises:
        ValueError: If the composite function's parameter list does not line
            up with the outer partition's parameters.

    NOTE(review): ``target`` is a free variable captured from the enclosing
    ``relay_to_runtime`` wrapper.
    """
    assert isinstance(partition, relay.Function)
    assert isinstance(partition.body, relay.Call)
    assert isinstance(partition.body.op, relay.Function)

    global_name = str(partition.attrs.global_symbol)
    comp_func = partition.body.op
    comp_name = comp_func.attrs["Composite"]
    assert comp_name in _LOWER_MAP
    assert isinstance(comp_func.body, relay.Call)

    # The PrimFunc signature is built from the *inner* composite function's
    # parameters, but the produced module is invoked with the arguments of
    # the *outer* partitioned function. If a pattern matched constants with
    # wildcard() (instead of is_constant()), the two lists diverge and the
    # call fails much later with a cryptic "num_args should be N" error.
    # Fail early with an actionable message instead.
    if len(comp_func.params) != len(partition.params):
        raise ValueError(
            f"{global_name}: composite function takes {len(comp_func.params)} "
            f"parameter(s) but the partitioned function exposes "
            f"{len(partition.params)}; match constant inputs with "
            f"is_constant() in the pattern so they are embedded as constants."
        )

    op = comp_func.body
    # One TE placeholder per composite-function parameter, preserving order.
    inputs = [
        te.placeholder(
            param.checked_type.shape,
            name=f"input_{i}",
            dtype=param.checked_type.dtype,
        )
        for i, param in enumerate(comp_func.params)
    ]
    output = _LOWER_MAP[comp_name](op, inputs)
    prim_func = te.create_prim_func(inputs + [output])
    return tvm.build(prim_func, target=target, name=global_name)
Here, inputs corresponds to the parameters of the inner fn called inside tvmgen_default_bmlib_main_0, not the parameters of tvmgen_default_bmlib_main_0 itself; create_prim_func uses inputs rather than the outer function's parameters, which causes the problem.
When I modify the pattern as follows, using is_constant() instead of wildcard(),
def reshape_conv2d_add_bn_relu_pattern2():
    """Create a reshape_conv2d_add_bn_relu_pattern
    """
    print("reshape_conv2d_add_bn_relu_pattern")
    # Constant operands are matched with is_constant() so the partitioner
    # embeds them inside the composite function instead of lifting them
    # into extra parameters.
    reshape_branch = is_op("reshape")(is_constant())
    conv_branch = is_op("nn.conv2d")(wildcard(), is_constant())
    fused = is_op("add")(reshape_branch, conv_branch)
    fused = is_op("nn.batch_norm")(
        fused, is_constant(), is_constant(), is_constant(), is_constant()
    )
    fused = is_tuple_get_item(fused, 0)
    return is_op("nn.relu")(fused)
and partition the graph again:
After partition graph...
def @main(%data: Tensor[(1, 128, 12, 40), float32] /* ty=Tensor[(1, 128, 12, 40), float32] */) -> Tensor[(1, 128, 12, 40), float32] {
@tvmgen_default_bmlib_main_0(%data) /* ty=Tensor[(1, 128, 12, 40), float32] */
}
def @tvmgen_default_bmlib_main_0(%bmlib_0_i0: Tensor[(1, 128, 12, 40), float32] /* ty=Tensor[(1, 128, 12, 40), float32] */, Inline=1, global_symbol="tvmgen_default_bmlib_main_0", Compiler="bmlib", Primitive=1) -> Tensor[(1, 128, 12, 40), float32] {
%5 = fn (%FunctionVar_0_0: Tensor[(1, 128, 12, 40), float32] /* ty=Tensor[(1, 128, 12, 40), float32] */, PartitionedFromPattern="reshape_nn.conv2d_add_nn.batch_norm_TupleGetItem0_nn.relu_", Composite="bmlib.reshape_conv2d_add_bn_relu") -> Tensor[(1, 128, 12, 40), float32] {
%0 = reshape(meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, newshape=[1, 128, 1, 1]) /* ty=Tensor[(1, 128, 1, 1), float32] */;
%1 = nn.conv2d(%FunctionVar_0_0, meta[relay.Constant][1] /* ty=Tensor[(128, 128, 1, 1), float32] */, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 128, 12, 40), float32] */;
%2 = add(%0, %1) /* ty=Tensor[(1, 128, 12, 40), float32] */;
%3 = nn.batch_norm(%2, meta[relay.Constant][2] /* ty=Tensor[(128), float32] */, meta[relay.Constant][3] /* ty=Tensor[(128), float32] */, meta[relay.Constant][4] /* ty=Tensor[(128), float32] */, meta[relay.Constant][5] /* ty=Tensor[(128), float32] */) /* ty=(Tensor[(1, 128, 12, 40), float32], Tensor[(128), float32], Tensor[(128), float32]) */;
%4 = %3.0 /* ty=Tensor[(1, 128, 12, 40), float32] */;
nn.relu(%4) /* ty=Tensor[(1, 128, 12, 40), float32] */
} /* ty=fn (Tensor[(1, 128, 12, 40), float32]) -> Tensor[(1, 128, 12, 40), float32] */;
%5(%bmlib_0_i0) /* ty=Tensor[(1, 128, 12, 40), float32] */
}
This time, the parameters of tvmgen_default_bmlib_main_0 and of the inner fn it calls are the same, and everything works after relay.build() and running the exported library.
I don't understand why the implementation in relay_to_runtime collects all the arguments of the inner subgraph function rather than those of the outer partitioned function — the first subgraph produced by the pattern seems logically correct.