Summary: I am trying to develop a BYOC integration that pattern-matches certain function calls, which seems like a straightforward application of the system. However, I am running into a strange bug at the graph partitioning stage and am not sure how to fix it. Please pardon the lengthy explanation, but it shows precisely how to reproduce the bug. @zhiics @comaniac, I would be glad to contribute a fix, but I do not know exactly what to change.
Here is the pattern I am trying to match (calls that carry an attribute, with up to 5 args):
import tvm
import tvm.relay.op.contrib  # loads the submodule that provides register_pattern_table
from tvm import relay
from tvm.relay import transform
from tvm.relay import dataflow_pattern as df

# call with 0 args
call_pattern = ((df.wildcard())()).has_attr({"match_key": "my_library"})
# extend the alternation with calls taking 1 to 5 args
for i in range(1, 6):
    call_pattern = call_pattern | (df.wildcard()(*[df.wildcard() for j in range(i)]).has_attr({"match_key": "my_library"}))

@tvm.relay.op.contrib.register_pattern_table("match_key")
def call_pattern_table():
    return [
        ('match_key', call_pattern, lambda _: True)
    ]
I can produce a call that matches this pattern like so:
def wrap_call(caller, *args):
    # tag the call with the attribute that the pattern looks for
    attr = tvm.ir.make_node("DictAttrs", **{"match_key": "my_library"})
    return relay.Call(caller, args, attrs=attr)
I apply this pattern to an example module as follows (per the tutorial):
def transform_mod(mod):
    pattern_table = call_pattern_table()
    mod = transform.MergeComposite(pattern_table)(mod)
    mod = transform.AnnotateTarget("match_key")(mod)
    mod = transform.MergeCompilerRegions()(mod)
    mod = transform.PartitionGraph()(mod)
    return mod
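Since the failure only shows up in the last pass, it can help to see the module after each step of the pipeline above. Here is a minimal sketch of a tracing helper. The name run_passes_with_trace is hypothetical (it is not part of TVM), and the helper only assumes that each pass is a callable taking a module and returning a module, which is how the passes are applied in transform_mod.

```python
from typing import Any, Callable, List, Sequence, Tuple


def run_passes_with_trace(
    mod: Any, passes: Sequence[Tuple[str, Callable[[Any], Any]]]
) -> Tuple[Any, List[Tuple[str, str]]]:
    """Apply (name, pass) pairs in order, recording the module after each one.

    Hypothetical debugging helper, not a TVM API. It stops at the first pass
    that raises and returns the last module that was produced successfully.
    """
    trace: List[Tuple[str, str]] = []
    for name, run_pass in passes:
        try:
            mod = run_pass(mod)
        except Exception as err:
            # Record which pass failed and keep the earlier snapshots.
            trace.append((name, f"FAILED: {err}"))
            break
        # str(mod) is the printed form of the module after this pass.
        trace.append((name, str(mod)))
    return mod, trace


# Stand-in "passes" on a string, only to show the behaviour without TVM.
demo_mod, demo_trace = run_passes_with_trace(
    "m",
    [("upper", str.upper), ("boom", lambda m: 1 / 0), ("never", str.lower)],
)
for pass_name, snapshot in demo_trace:
    print(pass_name, "->", snapshot)
```

With TVM, the list would presumably hold the same four pass objects that transform_mod builds, each paired with a label, so that the last entry in the trace names the pass that raised.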
It works for an example program that consists of one matching call:
def init_mod():
    mod = tvm.IRModule()
    # define library_call_0 .. library_call_5, taking 0 to 5 int32 scalars
    for i in range(6):
        params = [relay.Var(f'arg_{j}', type_annotation=relay.scalar_type('int32')) for j in range(i)]
        mod[f'library_call_{i}'] = relay.Function(params, relay.const(0, dtype='int32'), ret_type=relay.scalar_type('int32'))
    return mod
mod = init_mod()
mod['main'] = relay.Function([], wrap_call(mod.get_global_var('library_call_4'), *[relay.const(0, dtype='int32') for i in range(4)]))
print(transform_mod(mod)) # works
However, if I create a function that uses two tagged calls, then I get an error in the graph partitioning phase:
mod = init_mod()
x = relay.Var('x')
y = relay.Var('y')
z = relay.Var('z')
mod['main'] = relay.Function([],
    relay.Let(x,
        wrap_call(mod.get_global_var('library_call_3'),
                  *[relay.const(0, dtype='int32') for i in range(3)]),
        relay.Let(y,
            relay.const(0, dtype='int32'),
            relay.Let(z,
                wrap_call(mod.get_global_var('library_call_2'),
                          x, y),
                z))))
print(transform_mod(mod)) # fails
The resulting error says that free variables remain:
Check failed: fv.size() == 0 (2 vs. 0) : There are free variables: [Var(match_key_4_i0, ty=TensorType([], int32)), Var(match_key_8_i0, ty=TensorType([], int32))]
I inserted some prints in partition_graph.cc and saw that this was the main_function being produced after the PartitionGraph pass (before it was inserted into the new module, which results in the error):
FunctionNode([], TensorType([], int32),
  LetNode(
    Var(x, ty=TensorType([], int32)),
    Var(match_key_4_i0, ty=TensorType([], int32)),
    LetNode(Var(y, ty=TensorType([], int32)), Constant(0),
      LetNode(Var(z, ty=TensorType([], int32)),
        Var(match_key_8_i0, ty=TensorType([], int32)),
        Var(z, ty=TensorType([], int32))))), [], (nullptr))
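(Note the free variables: match_key_4_i0 and match_key_8_i0 are used as the values bound to x and z, but nothing in the function binds them.)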
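To make the failing check concrete, here is a minimal sketch of a free-variable computation over a toy let-expression tree that mirrors the printed main_function. The tuple encoding and the free_vars helper are hypothetical and are not TVM's implementation. The sketch also assumes a non-recursive let, where the bound name is in scope only in the body.

```python
def _merge(first, second):
    """Concatenate two name lists, keeping order and dropping duplicates."""
    return first + [name for name in second if name not in first]


def free_vars(expr, bound=frozenset()):
    """Return the names used in `expr` that no enclosing binder introduces.

    Toy encoding (an assumption made for illustration only):
      ("var", name), ("const", value),
      ("let", name, value, body), ("fn", [param names], body)
    """
    kind = expr[0]
    if kind == "var":
        return [] if expr[1] in bound else [expr[1]]
    if kind == "const":
        return []
    if kind == "let":
        _, name, value, body = expr
        # The value is checked in the outer scope; the body also sees `name`.
        return _merge(free_vars(value, bound), free_vars(body, bound | {name}))
    if kind == "fn":
        _, params, body = expr
        return free_vars(body, bound | set(params))
    raise ValueError(f"unknown node kind: {kind!r}")


# Same shape as the main_function printed above.
main_function = (
    "fn", [],
    ("let", "x", ("var", "match_key_4_i0"),
     ("let", "y", ("const", 0),
      ("let", "z", ("var", "match_key_8_i0"),
       ("var", "z")))),
)

print(free_vars(main_function))  # ['match_key_4_i0', 'match_key_8_i0']
```

The two names reported here are the same two that appear in the fv.size() == 0 check failure, which is consistent with the partitioned regions having been replaced by placeholder variables that were never bound.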
This is what main() looks like after the MergeCompilerRegions pass (this looks correct; it should not be possible to merge the regions):
def @main() -> int32 {
%0 = annotation.compiler_begin(@library_call_3, meta[relay.attrs.CompilerAttrs][0]) /* ty=fn (int32, int32, int32) -> int32 */;
%1 = annotation.compiler_begin(0 /* ty=int32 */, meta[relay.attrs.CompilerAttrs][1]) /* ty=int32 */;
%2 = annotation.compiler_begin(0 /* ty=int32 */, meta[relay.attrs.CompilerAttrs][2]) /* ty=int32 */;
%3 = annotation.compiler_begin(0 /* ty=int32 */, meta[relay.attrs.CompilerAttrs][3]) /* ty=int32 */;
%4 = fn (%FunctionVar_1_0: fn (int32, int32, int32) -> int32, %FunctionVar_1_1: int32, %FunctionVar_1_2: int32, %FunctionVar_1_3: int32, PartitionedFromPattern="Call_", Composite="match_key") -> int32 {
%FunctionVar_1_0(%FunctionVar_1_1, %FunctionVar_1_2, %FunctionVar_1_3, __dict__=meta[Map][0]) /* ty=int32 */
};
%5 = %4(%0, %1, %2, %3) /* ty=int32 */;
%6 = annotation.compiler_end(%5, meta[relay.attrs.CompilerAttrs][4]) /* ty=int32 */;
let %x: int32 = annotation.compiler_begin(%6, meta[relay.attrs.CompilerAttrs][5]) /* ty=int32 */;
let %y: int32 = annotation.compiler_begin(0 /* ty=int32 */, meta[relay.attrs.CompilerAttrs][6]) /* ty=int32 */;
%7 = annotation.compiler_begin(@library_call_2, meta[relay.attrs.CompilerAttrs][7]) /* ty=fn (int32, int32) -> int32 */;
%8 = annotation.compiler_begin(%x, meta[relay.attrs.CompilerAttrs][8]) /* ty=int32 */;
%9 = annotation.compiler_begin(%y, meta[relay.attrs.CompilerAttrs][9]) /* ty=int32 */;
%10 = fn (%FunctionVar_0_0: fn (int32, int32) -> int32, %FunctionVar_0_1: int32, %FunctionVar_0_2: int32, PartitionedFromPattern="Call_", Composite="match_key") -> int32 {
%FunctionVar_0_0(%FunctionVar_0_1, %FunctionVar_0_2, __dict__=meta[Map][0]) /* ty=int32 */
};
%11 = %10(%7, %8, %9) /* ty=int32 */;
%12 = annotation.compiler_end(%11, meta[relay.attrs.CompilerAttrs][10]) /* ty=int32 */;
let %z: int32 = annotation.compiler_begin(%12, meta[relay.attrs.CompilerAttrs][11]) /* ty=int32 */;
%z
}
Please let me know if anything about the syntax I am using is unsupported or known to be wrong. I would be happy to provide my code in a single Python file to make it easier to work with.
I believe this error has something to do with the implementation of Partitioner::Rewrite_(const CallNode*, const Expr&), since that function contains the code that produces the free vars, but the branch that is supposed to include their definitions is never run for the regions corresponding to those two free vars.
I would greatly appreciate any assistance! As I said, I would be glad to contribute a fix, but I am not sure how to fix partition_graph.cc to handle this case.