Confusing Behavior with Graph Partitioning

Summary: I am trying to develop a BYOC integration that pattern-matches for certain function calls, which seems like a straightforward application of the system. However, I am running into a strange bug at the graph-partitioning stage and am not sure how to fix it. Please pardon the lengthy explanation, but it shows precisely how to reproduce the bug. @zhiics @comaniac, I would be glad to contribute a fix but I do not know exactly what to change.

Here is the pattern I am trying to match for (calls with an attribute, with up to 5 args):

from tvm.relay import dataflow_pattern as df

# call with 0 args
call_pattern = (df.wildcard()()).has_attr({"match_key": "my_library"})
# union with calls of 1 through 5 args
for i in range(1, 6):
    args = [df.wildcard() for j in range(i)]
    call_pattern = call_pattern | df.wildcard()(*args).has_attr({"match_key": "my_library"})

@tvm.relay.op.contrib.register_pattern_table("match_key")
def call_pattern_table():
    return [
        ('match_key', call_pattern, lambda _: True)
    ]

I can produce a function that matches this pattern like so:

def wrap_call(caller, *args):
    attr = tvm.ir.make_node("DictAttrs", **{"match_key": "my_library"})
    return relay.Call(caller, args, attrs=attr)

I apply this pattern to an example module as follows (per the tutorial):

def transform_mod(mod):
    pattern_table = call_pattern_table()
    mod = transform.MergeComposite(pattern_table)(mod)
    mod = transform.AnnotateTarget("match_key")(mod)
    mod = transform.MergeCompilerRegions()(mod)
    mod = transform.PartitionGraph()(mod)
    return mod

It works for an example program that consists of one matching call:

def init_mod():
    mod = tvm.IRModule()
    for i in range(6):
        params = [relay.Var(f'arg_{j}', type_annotation=relay.scalar_type('int32'))
                  for j in range(i)]
        mod[f'library_call_{i}'] = relay.Function(
            params,
            relay.const(0, dtype='int32'),
            ret_type=relay.scalar_type('int32'))
    return mod

mod = init_mod()
mod['main'] = relay.Function([], wrap_call(
    mod.get_global_var('library_call_4'),
    *[relay.const(0, dtype='int32') for i in range(4)]))
print(transform_mod(mod)) # works

However, if I create a function that uses two tagged calls, then I get an error at the partition graphs phase:

mod = init_mod()
x = relay.Var('x')
y = relay.Var('y')
z = relay.Var('z')
mod['main'] = relay.Function([],
    relay.Let(x,
        wrap_call(mod.get_global_var('library_call_3'),
                  *[relay.const(0, dtype='int32') for i in range(3)]),
        relay.Let(y, 
            relay.const(0, dtype='int32'),
            relay.Let(z,
                wrap_call(mod.get_global_var('library_call_2'),
                          x, y),
                z))))
print(transform_mod(mod)) # fails

The resulting error is that free variables remain:

Check failed: fv.size() == 0 (2 vs. 0) : There are free variables: [Var(match_key_4_i0, ty=TensorType([], int32)), Var(match_key_8_i0, ty=TensorType([], int32))]

I inserted some prints in partition_graph.cc and saw that this was the main function produced by the PartitionGraph pass (before it was inserted into the new module, which is what triggers the error):

FunctionNode([], TensorType([], int32), 
    LetNode(
        Var(x, ty=TensorType([], int32)), 
        Var(match_key_4_i0, ty=TensorType([], int32)), 
        LetNode(Var(y, ty=TensorType([], int32)), Constant(0), 
            LetNode(Var(z, ty=TensorType([], int32)), 
                Var(match_key_8_i0, ty=TensorType([], int32)), 
                Var(z, ty=TensorType([], int32))))), [], (nullptr))

(Note the free variables.)

This is what the main() looks like after the merging regions pass (this looks correct; it should not be possible to merge the regions):

def @main() -> int32 {
  %0 = annotation.compiler_begin(@library_call_3, meta[relay.attrs.CompilerAttrs][0]) /* ty=fn (int32, int32, int32) -> int32 */;
  %1 = annotation.compiler_begin(0 /* ty=int32 */, meta[relay.attrs.CompilerAttrs][1]) /* ty=int32 */;
  %2 = annotation.compiler_begin(0 /* ty=int32 */, meta[relay.attrs.CompilerAttrs][2]) /* ty=int32 */;
  %3 = annotation.compiler_begin(0 /* ty=int32 */, meta[relay.attrs.CompilerAttrs][3]) /* ty=int32 */;
  %4 = fn (%FunctionVar_1_0: fn (int32, int32, int32) -> int32, %FunctionVar_1_1: int32, %FunctionVar_1_2: int32, %FunctionVar_1_3: int32, PartitionedFromPattern="Call_", Composite="match_key") -> int32 {
    %FunctionVar_1_0(%FunctionVar_1_1, %FunctionVar_1_2, %FunctionVar_1_3, __dict__=meta[Map][0]) /* ty=int32 */
  };
  %5 = %4(%0, %1, %2, %3) /* ty=int32 */;
  %6 = annotation.compiler_end(%5, meta[relay.attrs.CompilerAttrs][4]) /* ty=int32 */;
  let %x: int32 = annotation.compiler_begin(%6, meta[relay.attrs.CompilerAttrs][5]) /* ty=int32 */;
  let %y: int32 = annotation.compiler_begin(0 /* ty=int32 */, meta[relay.attrs.CompilerAttrs][6]) /* ty=int32 */;
  %7 = annotation.compiler_begin(@library_call_2, meta[relay.attrs.CompilerAttrs][7]) /* ty=fn (int32, int32) -> int32 */;
  %8 = annotation.compiler_begin(%x, meta[relay.attrs.CompilerAttrs][8]) /* ty=int32 */;
  %9 = annotation.compiler_begin(%y, meta[relay.attrs.CompilerAttrs][9]) /* ty=int32 */;
  %10 = fn (%FunctionVar_0_0: fn (int32, int32) -> int32, %FunctionVar_0_1: int32, %FunctionVar_0_2: int32, PartitionedFromPattern="Call_", Composite="match_key") -> int32 {
    %FunctionVar_0_0(%FunctionVar_0_1, %FunctionVar_0_2, __dict__=meta[Map][0]) /* ty=int32 */
  };
  %11 = %10(%7, %8, %9) /* ty=int32 */;
  %12 = annotation.compiler_end(%11, meta[relay.attrs.CompilerAttrs][10]) /* ty=int32 */;
  let %z: int32 = annotation.compiler_begin(%12, meta[relay.attrs.CompilerAttrs][11]) /* ty=int32 */;
  %z
}

Please let me know if there is anything about the syntax I am using that is unsupported or known to have been wrong. I would be happy to provide my code in a single Python file to make it easier to work with.

I believe this error has something to do with the implementation of Partitioner::Rewrite_(const CallNode*, const Expr&): that function contains the code that produces the free vars, but the branch that is supposed to emit their definitions is never run for the regions corresponding to those two free vars.

I would greatly appreciate any assistance! As I said, I would be glad to contribute a fix but I am not sure how to fix partition_graph.cc to handle this case.

The annotations on the LetNodes seem incorrect to me. Specifically, none of the LetNodes, including %x, %y, and %z, has a corresponding compiler_end. IIUC, we should not annotate the LetNodes.

cc @zhiics

Interesting. I would be happy to make changes to the annotation pass if we can specify the intended behavior. (I do not know exactly how it should work.)

Started a PR for a test case here: https://github.com/apache/tvm/pull/7318. The behavior in the test case appears to be the same as exhibited above, so if the annotation pass needs to be changed, I can put the fix into the above PR.

Hi @slyubomirsky,

May I ask how you ended up with the let nodes? Are you converting the graph to A-normal form prior to partitioning?

I kind of agree with @comaniac about annotating let nodes.

So, depending on the answer to the above question: would you be able to post a Relay snippet showing the outcome you expect after AnnotateTarget?

From a quick look at it, it seems the let-var annotation should not happen. After fixing that, AnnotateTarget will restrict the annotated regions to the bodies of the lets, and the evaluated values (i.e., the let vars) will be used by subsequent annotated regions. Moreover, if you intend to run MergeCompilerRegions, a value and its use will not be merged; I suppose that is your expectation?

So the lets will act as control-flow semantics here. Am I understanding your requirement correctly?

Thank you for the reply, @manupa-arm. To clarify, the let nodes were present in my initial program (I intend to do a sequence of function calls and bind the results to variables), and I did not use any pass to get rid of the let bindings. If you think I should use some pass to eliminate these bindings, I can do that, but in that case it should probably be documented somewhere.

I am not sure what I expect the annotations for let nodes to be, which is why I ask (I could not find any documentation on the intended usage of the compiler_begin/end annotations). What I would expect in the end-to-end example is to get two partitioned functions, each containing one of the calls that matches the pattern; as you note, no merging should happen. In the example in my first post, there are two separate function calls that should match the pattern (@library_call_3(0, 0, 0) and @library_call_2(%x, %y)), so I would expect the final output to contain two functions, each containing one of those calls.
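To make that expectation concrete, here is a rough sketch of the partitioned module I have in mind. The function names @match_key_0 and @match_key_1 are made up for illustration; the partitioner's actual naming scheme and attribute set may differ:

def @match_key_0(%i0: fn (int32, int32, int32) -> int32, %i1: int32, %i2: int32, %i3: int32, Compiler="match_key") -> int32 {
  %i0(%i1, %i2, %i3)
}

def @match_key_1(%j0: fn (int32, int32) -> int32, %j1: int32, %j2: int32, Compiler="match_key") -> int32 {
  %j0(%j1, %j2)
}

def @main() -> int32 {
  let %x: int32 = @match_key_0(@library_call_3, 0, 0, 0);
  let %y: int32 = 0;
  let %z: int32 = @match_key_1(@library_call_2, %x, %y);
  %z
}

The key property is that each matched call ends up in its own Compiler="match_key" function, while the let sequencing stays in main.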

Hi @slyubomirsky ,

Please refer to this for further info about annotations : [RFC][BYOC] An extended graph partitioning flow

I guess in your snippet, these compiler_begins should not happen:

let %x: int32 = annotation.compiler_begin(%6, meta[relay.attrs.CompilerAttrs][5]) /* ty=int32 */;
let %y: int32 = annotation.compiler_begin(0 /* ty=int32 */, meta[relay.attrs.CompilerAttrs][6]) /* ty=int32 */;
...
let %z: int32 = annotation.compiler_begin(%12, meta[relay.attrs.CompilerAttrs][11]) /* ty=int32 */;

I am not aware of any pass that eliminates let bindings.

I feel this might be a potential bug, as you noted in your PR. Out of curiosity, is this Relay produced by a frontend (i.e., imported from an ML model), or something you wrote in Relay by hand?

One workaround: if the calls have no side effects, they may not need any sequencing beyond what the data dependencies (as visible to Relay) require, so you may not need to introduce lets at all unless a frontend introduces them.

Thoughts? cc: @mbaret @dmitriy-arm

The let bindings were code that I was writing myself; they didn’t come from an existing model.

Thank you for directing my attention to the RFC. I will read it carefully and see if I can figure out the intended behavior for the test case I wrote.

So, here’s a proposal for what the behavior should be in this case. Based on my understanding of the RFC and the dataflow graph diagrams, I would expect the annotations for

# suppose %y is defined above
let %x = @f(%a, %b) # the call is tagged
  in
  (%x, %y)

(where the function call is what matches our pattern) to look like

%y1 = compiler_begin(%y, default);
%f1 = compiler_begin(@f, default);
%a1 = compiler_begin(%a, default);
%b1 = compiler_begin(%b, default);
%call = %f1(%a1, %b1);
%c1 = compiler_end(%call, my_label);
let %x = %c1 in
  %x1 = compiler_begin(%x, default);
  %t = (%x1, %y1);
  compiler_end(%t, default);

My reasoning is that the call node is a join point in the graph, so the inputs to the call node need to have a compiler_begin tag and the output needs to have a compiler_end tag. After the call, we have the tuple node (also a join point), for which the inputs will also have a compiler_begin tag and the output needs a compiler_end tag. Does that seem like a reasonable interpretation of the description in the RFC?

@slyubomirsky I think your proposal makes sense to me. The begins are actually tracking the inputs and ends are tracking the outputs (we can pack them into a tuple when there are multiple of them).

BTW, would it be better to convert them into graph form first, then annotate/partition, and finally convert it back to ANF?

Yes, I worked it out on paper by writing out a graph. I’ll have to dig into the code for the annotation pass to see why this behavior isn’t what results.

On second thought, I’m not completely sure how let nodes should behave given the linked RFC. If effects aren’t involved, you can conceptually treat let %x = definition in body by replacing %x with its definition everywhere it appears in the body (i.e., each occurrence of %x is an edge in the dataflow graph starting from the definition), but I’m not sure how to translate this into a dataflow graph and then recover the let nodes afterward.
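To make the substitution view concrete, here is a toy illustration in plain Python over a tuple-encoded AST. This is deliberately not TVM code; the encoding and helper names are made up for illustration, and it is only meaning-preserving when definitions have no side effects:

```python
# Toy AST: ("const", n) | ("var", name) | ("call", fn, [args]) | ("let", var, value, body)

def substitute(expr, name, value):
    """Replace every free occurrence of ("var", name) in expr with value."""
    kind = expr[0]
    if kind == "var":
        return value if expr[1] == name else expr
    if kind == "let":
        _, v, val, body = expr
        new_val = substitute(val, name, value)
        # An inner let that rebinds `name` shadows it in its body.
        new_body = body if v == name else substitute(body, name, value)
        return ("let", v, new_val, new_body)
    if kind == "call":
        return ("call", substitute(expr[1], name, value),
                [substitute(a, name, value) for a in expr[2]])
    return expr  # constants contain no variables

def inline_lets(expr):
    """Eliminate lets by substituting each bound var with its definition."""
    kind = expr[0]
    if kind == "let":
        _, v, val, body = expr
        return inline_lets(substitute(body, v, inline_lets(val)))
    if kind == "call":
        return ("call", inline_lets(expr[1]),
                [inline_lets(a) for a in expr[2]])
    return expr

# let %x = 0 in %x  ==>  0
print(inline_lets(("let", "x", ("const", 0), ("var", "x"))))  # -> ('const', 0)
```

In this view, each use of %x becomes a direct edge from the definition, which is exactly why it is unclear where the let nodes should reappear after a graph-based annotation pass.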

Should the definition in the let node contain a compiler_begin annotation but the variable have the compiler_end annotation wherever it appears in the body? Or should each appearance of the variable in the body have its own compiler_begin annotation? And should the definition always contain a compiler_end annotation?

For example, if we have let %x = 0 in %x, right now the annotation pass would produce let %x = compiler_begin(0, default) in %x, which seems strange since there is only a compiler_begin tag but no end tag.

Would the correct behavior be to put a compiler_end tag on each appearance of %x? For example,

let %x = compiler_begin(0, default) in
  compiler_end(%x, default)

Or should we put the compiler_end right into the definition, producing something like the following?

let %x = compiler_end(compiler_begin(0, default), default) in
  %x

@slyubomirsky I think the first annotation makes more sense. The pass was mainly used to annotate graph-form Relay programs; lets were not really handled well. This confuses me somewhat now. For example, if we create a compiler_end for each appearance of %x, what would the partitioned graph look like later on?

@slyubomirsky, I don’t think we need compiler_begins and ends for let vars; they should appear only in the bodies and values.

So this assumes you are OK with external functions being broken into values and their uses in the body separately (i.e., only the pure dataflow segments will go into the external function body, and the sequencing will remain in main).

There is a Relay pass called ToANormalForm that you can run after partitioning if you want to introduce lets for the whole graph, provided you don’t mind putting lets on all calls present in main. However, it will not respect any specific ordering you had before: if you had some fan-out nodes where more than one schedule is possible, the pass will just pick an order as the visitor traverses them.
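For reference, applying that pass would look roughly like this (a sketch, not tested here, assuming mod is the partitioned IRModule produced by transform_mod above):

mod = tvm.relay.transform.ToANormalForm()(mod)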