BYOC: different patterns partitioning the same subgraph give different results, and relay_to_runtime's behavior seems strange

I add an external lowering function with lower_composite("bmlib.conv2d") to lower the composite to TIR, and add a pattern:

from tvm.relay.dataflow_pattern import is_op, is_tuple_get_item, wildcard

def reshape_conv2d_add_bn_relu_pattern1():
    """Create a reshape_conv2d_add_bn_relu pattern."""
    print("reshape_conv2d_add_bn_relu_pattern")
    pattern1 = is_op("reshape")(wildcard())
    pattern2 = is_op("nn.conv2d")(wildcard(), wildcard())
    pattern = is_op("add")(pattern1, pattern2)
    pattern = is_op("nn.batch_norm")(pattern, wildcard(), wildcard(), wildcard(), wildcard())
    pattern = is_tuple_get_item(pattern, 0)
    pattern = is_op("nn.relu")(pattern)
    return pattern
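
For reference, the pattern is hooked up to the "bmlib" compiler through a pattern table, roughly like this (a simplified sketch; bmlib_pattern_table is just an illustrative name, not necessarily the exact code I use):

from tvm.relay.op.contrib.register import register_pattern_table

@register_pattern_table("bmlib")
def bmlib_pattern_table():
    # Map the composite name used by the codegen to the pattern above.
    return [("bmlib.reshape_conv2d_add_bn_relu", reshape_conv2d_add_bn_relu_pattern1())]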

When I create a Relay function as follows and call bind_params_by_name:

def @main(%data1: Tensor[(128), float32], %data: Tensor[(1, 128, 12, 40), float32], %weight: Tensor[(128, 128, 1, 1), float32], %gamma: Tensor[(128), float32], %beta: Tensor[(128), float32], %moving_mean: Tensor[(128), float32], %moving_var: Tensor[(128), float32]) {
  %0 = reshape(%data1, newshape=[1, 128, 1, 1]);
  %1 = nn.conv2d(%data, %weight, padding=[0, 0, 0, 0]);
  %2 = add(%0, %1);
  %3 = nn.batch_norm(%2, %gamma, %beta, %moving_mean, %moving_var);
  %4 = %3.0;
  nn.relu(%4)
}

after bind params...
def @main(%data: Tensor[(1, 128, 12, 40), float32]) {
  %0 = reshape(meta[relay.Constant][0], newshape=[1, 128, 1, 1]);
  %1 = nn.conv2d(%data, meta[relay.Constant][1], padding=[0, 0, 0, 0]);
  %2 = add(%0, %1);
  %3 = nn.batch_norm(%2, meta[relay.Constant][2], meta[relay.Constant][3], meta[relay.Constant][4], meta[relay.Constant][5]);
  %4 = %3.0;
  nn.relu(%4)
}
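
For completeness, the module is constructed and the parameters are bound roughly like this (a simplified sketch; the parameter values are random placeholders):

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 128, 12, 40))
data1 = relay.var("data1", shape=(128,))
weight = relay.var("weight", shape=(128, 128, 1, 1))
gamma = relay.var("gamma", shape=(128,))
beta = relay.var("beta", shape=(128,))
moving_mean = relay.var("moving_mean", shape=(128,))
moving_var = relay.var("moving_var", shape=(128,))

reshaped = relay.reshape(data1, newshape=[1, 128, 1, 1])
conv = relay.nn.conv2d(data, weight, padding=[0, 0, 0, 0])
bn = relay.nn.batch_norm(relay.add(reshaped, conv), gamma, beta, moving_mean, moving_var)
out = relay.nn.relu(bn[0])
func = relay.Function(relay.analysis.free_vars(out), out)

# Everything except %data is bound as a constant, which is why the bound IR
# above keeps only %data as a free parameter.
params = {
    "data1": np.random.uniform(size=(128,)).astype("float32"),
    "weight": np.random.uniform(size=(128, 128, 1, 1)).astype("float32"),
    "gamma": np.ones((128,), "float32"),
    "beta": np.zeros((128,), "float32"),
    "moving_mean": np.zeros((128,), "float32"),
    "moving_var": np.ones((128,), "float32"),
}
func = relay.build_module.bind_params_by_name(func, params)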

Then I partition the graph with the pattern and print the module:
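
The partitioning follows the usual BYOC pass pipeline, roughly (MergeCompilerRegions is omitted here since there is only one pattern):

from tvm.relay import transform

mod = tvm.IRModule.from_expr(func)
mod = transform.MergeComposite(bmlib_pattern_table())(mod)
mod = transform.AnnotateTarget("bmlib")(mod)
mod = transform.PartitionGraph()(mod)
mod = transform.InferType()(mod)
print("After partition graph...")
print(mod)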

After partition graph...
def @main(%data: Tensor[(1, 128, 12, 40), float32] /* ty=Tensor[(1, 128, 12, 40), float32] */) -> Tensor[(1, 128, 12, 40), float32] {
  @tvmgen_default_bmlib_main_0(%data) /* ty=Tensor[(1, 128, 12, 40), float32] */
}

def @tvmgen_default_bmlib_main_0(%bmlib_0_i1: Tensor[(1, 128, 12, 40), float32] /* ty=Tensor[(1, 128, 12, 40), float32] */, Inline=1, global_symbol="tvmgen_default_bmlib_main_0", Compiler="bmlib", Primitive=1) -> Tensor[(1, 128, 12, 40), float32] {
  %5 = fn (%FunctionVar_0_0: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %FunctionVar_0_1: Tensor[(1, 128, 12, 40), float32] /* ty=Tensor[(1, 128, 12, 40), float32] */, %FunctionVar_0_2: Tensor[(128, 128, 1, 1), float32] /* ty=Tensor[(128, 128, 1, 1), float32] */, %FunctionVar_0_3: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %FunctionVar_0_4: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %FunctionVar_0_5: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %FunctionVar_0_6: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, PartitionedFromPattern="reshape_nn.conv2d_add_nn.batch_norm_TupleGetItem0_nn.relu_", Composite="bmlib.reshape_conv2d_add_bn_relu") -> Tensor[(1, 128, 12, 40), float32] {
    %0 = reshape(%FunctionVar_0_0, newshape=[1, 128, 1, 1]) /* ty=Tensor[(1, 128, 1, 1), float32] */;
    %1 = nn.conv2d(%FunctionVar_0_1, %FunctionVar_0_2, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 128, 12, 40), float32] */;
    %2 = add(%0, %1) /* ty=Tensor[(1, 128, 12, 40), float32] */;
    %3 = nn.batch_norm(%2, %FunctionVar_0_3, %FunctionVar_0_4, %FunctionVar_0_5, %FunctionVar_0_6) /* ty=(Tensor[(1, 128, 12, 40), float32], Tensor[(128), float32], Tensor[(128), float32]) */;
    %4 = %3.0 /* ty=Tensor[(1, 128, 12, 40), float32] */;
    nn.relu(%4) /* ty=Tensor[(1, 128, 12, 40), float32] */
  } /* ty=fn (Tensor[(128), float32], Tensor[(1, 128, 12, 40), float32], Tensor[(128, 128, 1, 1), float32], Tensor[(128), float32], Tensor[(128), float32], Tensor[(128), float32], Tensor[(128), float32]) -> Tensor[(1, 128, 12, 40), float32] */;
  %5(meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, %bmlib_0_i1, meta[relay.Constant][1] /* ty=Tensor[(128, 128, 1, 1), float32] */, meta[relay.Constant][2] /* ty=Tensor[(128), float32] */, meta[relay.Constant][3] /* ty=Tensor[(128), float32] */, meta[relay.Constant][4] /* ty=Tensor[(128), float32] */, meta[relay.Constant][5] /* ty=Tensor[(128), float32] */) /* ty=Tensor[(1, 128, 12, 40), float32] */
}

After registering relay.ext.bmlib, I call relay.build:

tvm._ffi.register_func("relay.ext.bmlib", relay_to_runtime(tvm.target.Target("llvm")))
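
The build call itself is roughly:

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")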

and I encounter this error:

tvmgen_default_bmlib_main_0: num_args should be 8

Checking the relay_to_runtime function, I find:

def _relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module:
        """Compile Relay functions to a runtime module using Tensor Expressions."""
        assert isinstance(partition, relay.Function)
        assert isinstance(partition.body, relay.Call)
        assert isinstance(partition.body.op, relay.Function)

        global_name = str(partition.attrs.global_symbol)
        comp_func = partition.body.op
        comp_name = comp_func.attrs["Composite"]
        assert comp_name in _LOWER_MAP
        assert isinstance(comp_func.body, relay.Call)

        op = comp_func.body
        inputs = []
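        # Note: comp_func.params are the parameters of the inner composite fn,
        # not partition.params (the parameters of tvmgen_default_bmlib_main_0).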
        for i, param in enumerate(comp_func.params):
            inputs.append(
                te.placeholder(
                    param.checked_type.shape,
                    name=f"input_{i}",
                    dtype=param.checked_type.dtype,
                )
            )

        output = _LOWER_MAP[comp_name](op, inputs)
        prim_func = te.create_prim_func(inputs + [output])
        return tvm.build(prim_func, target=target, name=global_name)

Here inputs is built from comp_func.params, i.e. the parameters of the composite fn called inside tvmgen_default_bmlib_main_0, not the parameters of tvmgen_default_bmlib_main_0 itself. create_prim_func is then given those 7 inputs plus 1 output, so the generated PrimFunc expects 8 arguments, while at run time the external function is only called with the partitioned function's arguments (here just %bmlib_0_i1 plus the output), which causes the error.
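
A quick check makes the mismatch visible (the global function name below is taken from the partitioned IR above):

partition = mod["tvmgen_default_bmlib_main_0"]
comp_func = partition.body.op
print(len(partition.params))  # 1: what the runtime actually passes (plus one output)
print(len(comp_func.params))  # 7: what create_prim_func builds the PrimFunc from, hence "num_args should be 8"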

When I modify the pattern as follows, using is_constant() instead of wildcard(),

def reshape_conv2d_add_bn_relu_pattern2():
    """Create a reshape_conv2d_add_bn_relu pattern."""
    print("reshape_conv2d_add_bn_relu_pattern")
    pattern1 = is_op("reshape")(is_constant())
    pattern2 = is_op("nn.conv2d")(wildcard(), is_constant())
    pattern = is_op("add")(pattern1, pattern2)
    pattern = is_op("nn.batch_norm")(pattern, is_constant(), is_constant(), is_constant(), is_constant())
    pattern = is_tuple_get_item(pattern, 0)
    pattern = is_op("nn.relu")(pattern)
    return pattern

and partition the graph again:

After partition graph...
def @main(%data: Tensor[(1, 128, 12, 40), float32] /* ty=Tensor[(1, 128, 12, 40), float32] */) -> Tensor[(1, 128, 12, 40), float32] {
  @tvmgen_default_bmlib_main_0(%data) /* ty=Tensor[(1, 128, 12, 40), float32] */
}

def @tvmgen_default_bmlib_main_0(%bmlib_0_i0: Tensor[(1, 128, 12, 40), float32] /* ty=Tensor[(1, 128, 12, 40), float32] */, Inline=1, global_symbol="tvmgen_default_bmlib_main_0", Compiler="bmlib", Primitive=1) -> Tensor[(1, 128, 12, 40), float32] {
  %5 = fn (%FunctionVar_0_0: Tensor[(1, 128, 12, 40), float32] /* ty=Tensor[(1, 128, 12, 40), float32] */, PartitionedFromPattern="reshape_nn.conv2d_add_nn.batch_norm_TupleGetItem0_nn.relu_", Composite="bmlib.reshape_conv2d_add_bn_relu") -> Tensor[(1, 128, 12, 40), float32] {
    %0 = reshape(meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, newshape=[1, 128, 1, 1]) /* ty=Tensor[(1, 128, 1, 1), float32] */;
    %1 = nn.conv2d(%FunctionVar_0_0, meta[relay.Constant][1] /* ty=Tensor[(128, 128, 1, 1), float32] */, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 128, 12, 40), float32] */;
    %2 = add(%0, %1) /* ty=Tensor[(1, 128, 12, 40), float32] */;
    %3 = nn.batch_norm(%2, meta[relay.Constant][2] /* ty=Tensor[(128), float32] */, meta[relay.Constant][3] /* ty=Tensor[(128), float32] */, meta[relay.Constant][4] /* ty=Tensor[(128), float32] */, meta[relay.Constant][5] /* ty=Tensor[(128), float32] */) /* ty=(Tensor[(1, 128, 12, 40), float32], Tensor[(128), float32], Tensor[(128), float32]) */;
    %4 = %3.0 /* ty=Tensor[(1, 128, 12, 40), float32] */;
    nn.relu(%4) /* ty=Tensor[(1, 128, 12, 40), float32] */
  } /* ty=fn (Tensor[(1, 128, 12, 40), float32]) -> Tensor[(1, 128, 12, 40), float32] */;
  %5(%bmlib_0_i0) /* ty=Tensor[(1, 128, 12, 40), float32] */
}

This time the parameters of tvmgen_default_bmlib_main_0 and of the inner fn it calls match, and everything is fine after relay.build() and running the exported library.
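
For reference, building, exporting and running now works roughly like this (the library path is just an example):

from tvm.contrib import graph_executor

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")
lib.export_library("bmlib_model.so")

loaded = tvm.runtime.load_module("bmlib_model.so")
dev = tvm.cpu()
rt = graph_executor.GraphModule(loaded["default"](dev))
rt.set_input("data", np.random.uniform(size=(1, 128, 12, 40)).astype("float32"))
rt.run()
out = rt.get_output(0).numpy()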

What I don't understand is why relay_to_runtime is implemented this way: why does it collect all the arguments of the composite subgraph instead of the parameters of the partitioned function? The first subgraph produced by pattern1 looks logically correct to me.