A CallNode exists in Relay IR but its PrimFunc is gone

Hi all,

I ran into an issue while working on a design. I added a TVM hash_table_create node to “fuse” the TensorFlow HashTableV2 and LookupTableImportV2 ops. The hash_table_create node is created correctly in Relay IR, but it is gone (or optimized out) at the PrimFunc level, so the codegen unexpectedly has no hash table node. From my debug prints, though, I can see that the TOPI compute is still invoked during the scheduling and optimization phase.

Here is the Relay IR output:

...
  %517 = hash_table_create(%unknown_19/const_0, %unknown_19/const_1, key_dtype="object64", value_dtype="int64", shared_name="table_3487_load_1") /* Func/StatefulPartitionedCall/input/_505 */ /* ty=(Tensor[(1), custom[hashtable]64],) */;
  %518 = lookup_table_find(%517, %plat, 1 /* ty=int64 */, key_dtype="object64", value_dtype="int64", dtype="int64") /* StatefulPartitionedCall/functional_1/string_lookup/None_lookup_table_find/LookupTableFindV2 */ /* ty=Tensor[(1, 1), int64] */;
...
  %973 = take(%unknown_429, %518, axis=0) /* StatefulPartitionedCall/functional_1/embed_plat/embedding_lookup/Identity */ /* ty=Tensor[(1, 1, 8), float32] */;

But after TIR optimization, there is no PrimFunc for hash_table_create:

, GlobalVar(tvmgen_default_fused_take_1): PrimFunc([placeholder, placeholder, T_take]) attrs={"from_legacy_te_schedule": (bool)1, "global_symbol": "tvmgen_default_fused_take_1", "tir.noalias": (bool)1, "target": llvm -keys=cpu -libs=dnnl -link-params=0 -mattr=sse4.2,fma,avx512f -mcpu=x86-64 -opt-level=3} {
  parallel (ax0.ax1.fused, 0, 64) {
    T_take[ramp((ax0.ax1.fused*8), 1, 8)] = placeholder[(x8((min(max((int64)0, placeholder[ax0.ax1.fused]), (int64)49999)*(int64)8)) + int64x8(ramp(0, 1, 8)))]
  }
}
, GlobalVar(tvmgen_default_fused_lookup_table_find): PrimFunc([placeholder, placeholder, placeholder, hash_table_find]) attrs={"from_legacy_te_schedule": (bool)1, "global_symbol": "tvmgen_default_fused_lookup_table_find", "tir.noalias": (bool)1, "target": llvm -keys=cpu -libs=dnnl -link-params=0 -mattr=sse4.2,fma,avx512f -mcpu=x86-64 -opt-level=3} {
  // attr [0] extern_scope = 0
  tir.tvm_call_packed("tvm.contrib.hashtable.find", tir.tvm_stack_make_array(placeholder, tir.tvm_stack_make_shape(1), 0, 1, 0, 0), tir.tvm_stack_make_array(placeholder, tir.tvm_stack_make_shape(1, 1), 0, 2, (object64)0, 0), tir.tvm_stack_make_array(placeholder, tir.tvm_stack_make_shape(), 0, 0, (int64)0, 0), tir.tvm_stack_make_array(hash_table_find, tir.tvm_stack_make_shape(1, 1), 0, 2, (int64)0, 0))
}

The hash_table_create CallNode is an independent node whose only inputs are two ConstNodes. Its output is the input of another CallNode, lookup_table_find. Which TIR pass would eliminate this node, and why?

In “[Dev API] Possible to remove AST nodes?”, tianqi answered: “Instead, in order to remove something, we might want to simply return an Evaluate(0)(which is a nop) in post order transform”. I checked the code and RemoveNoOp looks suspicious, but I haven’t figured out the whole picture yet.

I also checked RemoveNoOp, RemoveUnusedFunctions and some other TIR passes, but never saw a hash_table_create PrimFunc in their output.
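
To check which lowered functions survive, one option is to scan the printed module for GlobalVar names instead of reading the dump by eye. This is only a sketch: lowered_function_names is a hypothetical helper, not a TVM API, and the dump string is a shortened stand-in for the output above with the function bodies left out.

```python
import re


def lowered_function_names(dump: str) -> list[str]:
    """Return the name inside every `GlobalVar(...)` found in a printed module dump."""
    return re.findall(r"GlobalVar\(([^)]+)\)", dump)


# Shortened stand-in for the dump above (function bodies omitted).
dump = """
, GlobalVar(tvmgen_default_fused_take_1): PrimFunc([placeholder, placeholder, T_take])
, GlobalVar(tvmgen_default_fused_lookup_table_find): PrimFunc([placeholder, placeholder, placeholder, hash_table_find])
"""

names = lowered_function_names(dump)

# Relay ops that are expected to have a lowered function somewhere in the module.
expected_ops = ("hash_table_create", "lookup_table_find")
missing = [op for op in expected_ops if not any(op in name for name in names)]

print(names)
print("missing:", missing)
```

Running the same check on the output of each pass would show the first point at which the function disappears, or that it was never added to the module in the first place.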

Could you please shed some light on it? Thanks.

@junrushao @vinx13 could you please give some advice? Thank you!

I added more printing of the PrimFunc, and I can see that the PrimFunc for hash_table_create is generated correctly:

====WENXIAN:lower_call() op= hash_table_create
=====select_implementation() impl: relay.OpImplementation(0xc612150) op.name: hash_table_create
====WENXIAN:hash_table_create TOPI called
=====select_implementation() impl: relay.OpImplementation(0xc612150) outputs: [Tensor(shape=[1], op.name=hash_table_create)]
====WENXIAN: ScheduleBuilder.Create(): intrp_fused_hash_table_create
====WENXIAN: LowerInternal() prim_func.undefined func_name:intrp_fused_hash_table_create
====WENXIAN: LowerSchedule() name:intrp_fused_hash_table_create
====WENXIAN: ScheduleToModule() PrimFunc: primfn(placeholder_2: handle, placeholder_3: handle, hash_table_create_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "intrp_fused_hash_table_create", "tir.noalias": True}
  buffers = {hash_table_create: Buffer(hash_table_create_2: Pointer(custom[hashtable]64), custom[hashtable]64, [1], []),
             placeholder: Buffer(placeholder_4: Pointer(object64), object64, [12], []),
             placeholder_1: Buffer(placeholder_5: Pointer(int64), int64, [12], [])}
  buffer_map = {placeholder_2: placeholder, placeholder_3: placeholder_1, hash_table_create_1: hash_table_create} {
  realize(hash_table_create, [0:1], True {
    attr [[placeholder_6: Buffer(placeholder_7: Pointer(object64), object64, [12], []), placeholder]] "buffer_bind_scope" = @tir.tvm_tuple(0, 12, dtype=handle);
    attr [[placeholder_8: Buffer(placeholder_9: Pointer(int64), int64, [12], []), placeholder_1]] "buffer_bind_scope" = @tir.tvm_tuple(0, 12, dtype=handle);
    attr [[hash_table_create_3: Buffer(hash_table_create_4: Pointer(custom[hashtable]64), custom[hashtable]64, [1], []), hash_table_create]] "buffer_bind_scope" = @tir.tvm_tuple(0, 1, dtype=handle);
    attr [0] "extern_scope" = 0;
    @tir.tvm_call_packed("tvm.contrib.hashtable.create", @tir.tvm_stack_make_array(placeholder_7, @tir.tvm_stack_make_shape(12, dtype=handle), 0, 1, 064, 0, dtype=handle), @tir.tvm_stack_make_array(placeholder_9, @tir.tvm_stack_make_shape(12, dtype=handle), 0, 1, 0i64, 0, dtype=handle), @tir.tvm_stack_make_array(hash_table_create_4, @tir.tvm_stack_make_shape(1, dtype=handle), 0, 1, 064, 0, dtype=handle), "object64", "int64", "table_3487_load_1", dtype=int32)
  })
}

====WENXIAN: LowerTensorExprMutator::LowerFunction() for primitive:fn (%p0: Tensor[(12), object64], %p1: Tensor[(12), int64], Primitive=1) -> (Tensor[(1), custom[hashtable]64],) {
  hash_table_create(%p0, %p1, key_dtype="object64", value_dtype="int64", shared_name="table_3487_load_1") /* ty=(Tensor[(1), custom[hashtable]64],) */
}====WENXIAN:LowerFunction() lowered primitive bound to:@intrp_fused_hash_table_create
====WENXIAN: TIRToPackedFunc() var->name_hint:intrp_fused_hash_table_create

The following is lookup_table_find’s PrimFunc:

One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
====WENXIAN:lower_call() op= lookup_table_find
=====select_implementation() impl: relay.OpImplementation(0xcb54360) op.name: lookup_table_find
====WENXIAN:lookup_table_find TOPI called
=====select_implementation() impl: relay.OpImplementation(0xcb54360) outputs: [Tensor(shape=[1, 1], op.name=hash_table_find)]
====WENXIAN: ScheduleBuilder.Create(): tvmgen_default_fused_lookup_table_find
====WENXIAN: ScheduleToModule() PrimFunc: primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, hash_table_find_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_lookup_table_find", "tir.noalias": True}
  buffers = {hash_table_find: Buffer(hash_table_find_2: Pointer(int64), int64, [1, 1], []),
             placeholder_2: Buffer(placeholder_6: Pointer(int64), int64, [], []),
             placeholder: Buffer(placeholder_7: Pointer(custom[hashtable]64), custom[hashtable]64, [1], []),
             placeholder_1: Buffer(placeholder_8: Pointer(object64), object64, [1, 1], [])}
  buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, hash_table_find_1: hash_table_find} {
  realize(hash_table_find, [0:1, 0:1], True {
    attr [[placeholder_9: Buffer(placeholder_10: Pointer(custom[hashtable]64), custom[hashtable]64, [1], []), placeholder]] "buffer_bind_scope" = @tir.tvm_tuple(0, 1, dtype=handle);
    attr [[placeholder_11: Buffer(placeholder_12: Pointer(object64), object64, [1, 1], []), placeholder_1]] "buffer_bind_scope" = @tir.tvm_tuple(0, 1, 0, 1, dtype=handle);
    attr [[placeholder_13: Buffer(placeholder_14: Pointer(int64), int64, [], []), placeholder_2]] "buffer_bind_scope" = @tir.tvm_tuple(, dtype=handle);
    attr [[hash_table_find_3: Buffer(hash_table_find_4: Pointer(int64), int64, [1, 1], []), hash_table_find]] "buffer_bind_scope" = @tir.tvm_tuple(0, 1, 0, 1, dtype=handle);
    attr [0] "extern_scope" = 0;
    @tir.tvm_call_packed("tvm.contrib.hashtable.find", @tir.tvm_stack_make_array(placeholder_10, @tir.tvm_stack_make_shape(1, dtype=handle), 0, 1, 064, 0, dtype=handle), @tir.tvm_stack_make_array(placeholder_12, @tir.tvm_stack_make_shape(1, 1, dtype=handle), 0, 2, 064, 0, dtype=handle), @tir.tvm_stack_make_array(placeholder_14, @tir.tvm_stack_make_shape(, dtype=handle), 0, 0, 0i64, 0, dtype=handle), @tir.tvm_stack_make_array(hash_table_find_4, @tir.tvm_stack_make_shape(1, 1, dtype=handle), 0, 2, 0i64, 0, dtype=handle), dtype=int32)
  })
}

====WENXIAN: LowerTensorExprMutator::LowerFunction() for primitive:fn (%p0: (Tensor[(1), custom[hashtable]64],), %p1: Tensor[(1, 1), object64], %p2: int64, Primitive=1, hash="7e77540a249cf6a1") -> Tensor[(1, 1), int64] {
  lookup_table_find(%p0, %p1, %p2, key_dtype="object64", value_dtype="int64", dtype="int64") /* ty=Tensor[(1, 1), int64] */
}====WENXIAN:LowerFunction() lowered primitive bound to:@tvmgen_default_fused_lookup_table_find
llvm target triple: x86_64-unknown-linux-gnu
====WENXIAN: tvmgen_default_fused_lookup_table_find inputs eid:1929
====WENXIAN: tvmgen_default_fused_lookup_table_find inputs eid:413
====WENXIAN: tvmgen_default_fused_lookup_table_find inputs eid:1930

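I can’t see any difference between these two nodes, but it seems hash_table_create goes through the Relay interpreter (its lowered function is named intrp_fused_hash_table_create), while lookup_table_find doesn’t (its lowered function is named tvmgen_default_fused_lookup_table_find).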

After debugging, I suspect this node has been evaluated away by the interpreter because both of its inputs are constant nodes. In my case that is not the expected behavior. I created a new question to follow up in this direction: “How to skip executing a CallNode in Relay Interpreter?” Please correct me if anything is wrong. Thanks. @jroesch @yuchenj @junrushao @vinx13

After investigating the Relay pass codebase, I believe the node is optimized out by the constant folding pass (FoldConstant), which evaluates a call at compile time when all of its arguments are constants:

    if (!std::all_of(post_call->args.begin(), post_call->args.end(), IsComplexConstant)) {
      // At least one non-constant argument.
      return std::move(post_call);
    }
    // During evaluation we have obviously lost all on_device annotations. However any
    // on_device wrapping this call will be left in place.
    return ConstEvaluate(post_call);
  }
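
The effect of that check can be reproduced with a toy model. The sketch below is not TVM code: Const, Var, Call, EVALUATORS and fold are hypothetical stand-ins, and the skip_ops parameter is just one conceivable way to opt an op out of folding, not necessarily how this was fixed. It shows that a call whose arguments are all constants is replaced by its evaluated result, so no call (and hence no lowered function) remains in the final program.

```python
from dataclasses import dataclass
from typing import Callable, Dict, FrozenSet, Tuple, Union


@dataclass(frozen=True)
class Const:
    value: object


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple["Expr", ...]


Expr = Union[Const, Var, Call]

# Hypothetical evaluators standing in for "run the op at compile time".
EVALUATORS: Dict[str, Callable[..., object]] = {
    "hash_table_create": lambda keys, values: dict(zip(keys, values)),
}


def fold(expr: Expr, skip_ops: FrozenSet[str] = frozenset()) -> Expr:
    """Post-order fold: evaluate a call once all of its arguments are constants."""
    if not isinstance(expr, Call):
        return expr
    args = tuple(fold(arg, skip_ops) for arg in expr.args)
    call = Call(expr.op, args)
    if expr.op in skip_ops:
        return call  # opted out of folding, stays a call
    if not all(isinstance(arg, Const) for arg in args):
        return call  # at least one non-constant argument
    return Const(EVALUATORS[expr.op](*(arg.value for arg in args)))


create = Call("hash_table_create", (Const(("a", "b")), Const((1, 2))))
find = Call("lookup_table_find", (create, Var("plat")))

folded = fold(find)
kept = fold(find, skip_ops=frozenset({"hash_table_create"}))

print(type(folded.args[0]).__name__)  # the create call was replaced by a constant
print(type(kept.args[0]).__name__)    # the create call survives
```

In this model lookup_table_find is never folded because one of its arguments is a variable, which matches what the dumps show: only the call with two constant inputs disappears.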

I resolved the problem, thanks all!