How to add a CallNode as an input for a relay.Function?

Hello. I’m a new TVM user. To initialize a hash table alongside the graph runtime, I want to:

  1. Split out the HashTableV2 CallNodes in the graph whose outputs connect to the LookupTableFindV2 CallNode.
  2. Replace the LookupTableFindV2 node’s input with a Var.
  3. (a) Create a new CallNode (op_name=table_initialize) whose inputs are all the HashTableV2 CallNodes.
  4. (a) Make the table_initialize CallNode the output of a new global function “init”.
  5. (b) Create a new relay.Function whose inputs are all the HashTableV2 CallNodes.
  6. (b) Make the new function a new global function, “init”.
  7. The module then has two functions: main and init.
  8. Run codegen for each function: main and init.
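The plan above can be sketched on a toy graph. The `Node` class and the `split_hashtables` helper below are hypothetical stand-ins, not TensorFlow's NodeDef or any TVM frontend API, and the `var:` prefix only marks where a relay Var would replace the edge. The sketch covers steps 1, 2 and 7.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    # Toy stand-in for a TensorFlow graph node (hypothetical, not NodeDef)
    name: str
    op: str
    inputs: list = field(default_factory=list)


def split_hashtables(nodes):
    """Split HashTableV2 nodes that feed LookupTableFindV2 out of the graph.

    Returns (main_nodes, init_nodes). In main, each edge from a split table
    is replaced by a "var:<name>" placeholder standing in for a relay Var.
    """
    by_name = {n.name: n for n in nodes}

    # Step 1: find HashTableV2 nodes whose output feeds a LookupTableFindV2
    table_names = []
    for n in nodes:
        if n.op == "LookupTableFindV2":
            for inp in n.inputs:
                src = by_name.get(inp)
                if src is not None and src.op == "HashTableV2" and inp not in table_names:
                    table_names.append(inp)

    # Step 2: drop split tables from main and rewire lookups to placeholders
    main_nodes = []
    for n in nodes:
        if n.name in table_names:
            continue
        if n.op == "LookupTableFindV2":
            n = Node(n.name, n.op,
                     [f"var:{i}" if i in table_names else i for i in n.inputs])
        main_nodes.append(n)

    init_nodes = [by_name[t] for t in table_names]
    return main_nodes, init_nodes


graph = [
    Node("table", "HashTableV2"),
    Node("keys", "Placeholder"),
    Node("find", "LookupTableFindV2", ["table", "keys", "default"]),
    Node("default", "Const"),
]
main_nodes, init_nodes = split_hashtables(graph)
# Step 7: the module holds two functions
module = {"main": main_nodes, "init": init_nodes}
print([n.name for n in module["main"]])  # ['keys', 'find', 'default']
print(module["main"][1].inputs)          # ['var:table', 'keys', 'default']
print([n.name for n in module["init"]])  # ['table']
```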

Note that steps 3-4 and 5-6 are two alternatives. I ran into a problem when adding a new Function (steps 5-6).

g._hashtable_nodes holds all the hashtable nodes. Here I added an _expr.var for each one, which is wrong. How can I add an _expr.Call (a CallNode) for each HashTableV2 node instead?

    # In the graph converter: collect the HashTableV2 nodes
    # (currently as Vars, which is the part in question)
    if "HashTableV2" in node.op:
        self._hashtable_nodes.append(_expr.var(node.name, shape=(1, 1), dtype="object"))

    # Build "main" from the TensorFlow graph
    main_func, params = g.from_tensorflow(graph_def,
                                          layout,
                                          shape,
                                          outputs,
                                          gdef_lib=graph_def_library,
                                          init_phase=False)
    module.mod["main"] = main_func
    module.params.update(params)

    # Build "init" from the collected hashtable nodes
    if len(g._hashtable_nodes):
        data_shape = (1, 1)
        # note: `data` is not used below
        data = tvm.relay.var("data", tvm.relay.TensorType(data_shape, "object"))
        init_func = tvm.relay.Function(g._hashtable_nodes, g._hashtable_nodes[0])
        module.mod["init"] = init_func

The topology looks roughly like this: C is the HashTableV2 node, and Op is the new Function or CallNode.

Please correct me if I have misunderstood anything.

Hi @wenxian,

The parameters of a relay.Function are an array of relay.Var, which means they cannot be CallNodes: https://github.com/apache/tvm/blob/main/include/tvm/relay/function.h#L42.

Edit: I noticed you are creating a Function with a list of Vars. Could you elaborate on “I added an _expr.var here which is wrong”?
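To illustrate the constraint, here is a toy IR that enforces the same rule as Relay: a Call cannot be a parameter, but it can appear in the function body. The `Var`, `Call`, `TupleExpr` and `Function` classes are hypothetical, not TVM's; in real Relay the analogous constructors are `relay.Function(params, body)` and `relay.Tuple(fields)`. Returning a tuple of all table calls from "init" is one possible design, assumed here rather than confirmed in this thread.

```python
from dataclasses import dataclass
from typing import Tuple


# Toy IR classes (hypothetical, only mirroring Relay's Var/Call/Tuple roles)
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple = ()


@dataclass(frozen=True)
class TupleExpr:
    fields: Tuple


class Function:
    def __init__(self, params, body):
        # Same rule as relay.Function: every parameter must be a Var
        for p in params:
            if not isinstance(p, Var):
                raise TypeError(f"Function params must be Var, got {type(p).__name__}")
        self.params = list(params)
        self.body = body


tables = (Call("HashTableV2", ("table_a",)), Call("HashTableV2", ("table_b",)))

# Passing the calls as parameters is rejected ...
error = None
try:
    Function(list(tables), tables[0])
except TypeError as exc:
    error = str(exc)
print(error)  # Function params must be Var, got Call

# ... but they can live in the body: no params, all tables returned together
init = Function([], TupleExpr(tables))
print(len(init.params), len(init.body.fields))  # 0 2
```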

@yuchenj Thanks for the answer. I think I need to create a CallNode. How can I write a dummy CallNode that has no specific function and whose only purpose is to take all the HashTableV2 nodes as inputs, so that the HashTableV2 CallNodes are not optimized out? (They are not connected to the main function, directly or indirectly.)
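A likely reason the split-off nodes disappear, stated here as an assumption: a Relay function only contains the expressions reachable from its body, so a node that nothing consumes is not kept. The toy reachability check below (a hypothetical `live_nodes` helper over a plain dict of edges) shows how a dummy consumer keeps such a node alive.

```python
def live_nodes(edges, outputs):
    """Return every node reachable backwards from `outputs` via input edges."""
    live, stack = set(), list(outputs)
    while stack:
        node = stack.pop()
        if node not in live:
            live.add(node)
            stack.extend(edges.get(node, []))
    return live


# node -> list of its inputs; "table" was split out and nothing consumes it
edges = {
    "find": ["var:table", "keys"],
    "table": [],
}
print(sorted(live_nodes(edges, ["find"])))
# ['find', 'keys', 'var:table']  ("table" is unreachable, so it is dead)

# A hypothetical dummy op that consumes the table keeps it reachable
edges["keep_tables"] = ["table"]
print(sorted(live_nodes(edges, ["find", "keep_tables"])))
# ['find', 'keep_tables', 'keys', 'table', 'var:table']
```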

One way to do this is to create a dummy operator and implement special lowering logic for it.
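The dummy-operator idea amounts to dispatching one op name to custom lowering logic. The registry below is a plain-Python toy: `register_lowering`, `lower` and the `init_table(...)` output strings are all hypothetical. In TVM this would go through its own operator registration machinery, which is not shown here.

```python
# Hypothetical lowering registry: op name -> function producing lowered code
LOWERING = {}


def register_lowering(op_name):
    def wrap(fn):
        LOWERING[op_name] = fn
        return fn
    return wrap


@register_lowering("table_initialize")  # dummy op name from step 3
def lower_table_init(args):
    # Special rule: emit one init call per table instead of a compute kernel
    return [f"init_table({a})" for a in args]


def lower(op, args):
    if op not in LOWERING:
        raise NotImplementedError(f"no lowering rule for {op}")
    return LOWERING[op](args)


print(lower("table_initialize", ["table_a", "table_b"]))
# ['init_table(table_a)', 'init_table(table_b)']
```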

I think what you mean is that you need two global functions, main and init, which you want to invoke separately. If that’s the case, you don’t need to call “init” from “main”, so you don’t need to create a CallNode for it.
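Invoking the two global functions separately can be pictured with a toy module keyed by function name. `ToyModule` is hypothetical; it only mimics the `module.mod["init"] = init_func` style of assignment used in the question. The host runs "init" once and then calls "main", which never calls "init" itself.

```python
class ToyModule:
    """Hypothetical stand-in for a module holding global functions by name."""

    def __init__(self):
        self.functions = {}

    def __setitem__(self, name, fn):
        self.functions[name] = fn

    def __getitem__(self, name):
        return self.functions[name]


state = {}
mod = ToyModule()
# "init" fills the table once; "main" only reads it and never calls "init"
mod["init"] = lambda: state.setdefault("table", {"a": 1, "b": 2})
mod["main"] = lambda key: state["table"].get(key, -1)

mod["init"]()            # the host invokes "init" first, on its own
print(mod["main"]("a"))  # 1
print(mod["main"]("z"))  # -1
```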


Thanks @yuchenj, I’ll try it!