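I am using these patterns with relay.transform.MergeComposite. I do not think concat(wildcard()) pulls the input tuple into the pattern: in the output, the tuple %0 is built outside the composite function and passed in as a single argument. This is the result when executing the code below: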
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import wildcard, is_op
data = []
for i in range(2):
    data.append(relay.var("in_{}".format(i), shape=(1, 4, 4, 3), dtype='float32'))
z = relay.concatenate(data, axis=1)
concat = relay.Function(relay.analysis.free_vars(z), z)
mod = tvm.IRModule.from_expr(concat)
concat_pattern = is_op('concatenate')(wildcard())
mod = relay.transform.MergeComposite([('concat_pattern', concat_pattern)])(mod)
print(mod)
Output:
def @main(%in_0: Tensor[(1, 4, 4, 3), float32], %in_1: Tensor[(1, 4, 4, 3), float32]) -> Tensor[(1, 8, 4, 3), float32] {
  %0 = (%in_0, %in_1);
  %1 = fn (%FunctionVar_0_0: (Tensor[(1, 4, 4, 3), float32], Tensor[(1, 4, 4, 3), float32]), PartitionedFromPattern="concatenate_", Composite="concat_pattern", Primitive=1) -> Tensor[(1, 8, 4, 3), float32] {
    concatenate(%FunctionVar_0_0, axis=1) /* ty=Tensor[(1, 8, 4, 3), float32] */
  };
  %1(%0) /* ty=Tensor[(1, 8, 4, 3), float32] */
}