Hi all,
I ran into an issue while working on a design. I added a TVM hash_table_create node to “fuse” the TensorFlow HashTableV2 and LookupTableImportV2 ops. The hash_table_create node is created correctly in Relay IR, but it is gone (optimized out?) in the PrimFuncs, so codegen never sees a HashTable node, which is unexpected. Yet the debug printing shows that its TOPI implementation does run during the scheduling and optimization phase.
Here is the Relay IR output:
```
...
%517 = hash_table_create(%unknown_19/const_0, %unknown_19/const_1, key_dtype="object64", value_dtype="int64", shared_name="table_3487_load_1") /* Func/StatefulPartitionedCall/input/_505 */ /* ty=(Tensor[(1), custom[hashtable]64],) */;
%518 = lookup_table_find(%517, %plat, 1 /* ty=int64 */, key_dtype="object64", value_dtype="int64", dtype="int64") /* StatefulPartitionedCall/functional_1/string_lookup/None_lookup_table_find/LookupTableFindV2 */ /* ty=Tensor[(1, 1), int64] */;
...
%973 = take(%unknown_429, %518, axis=0) /* StatefulPartitionedCall/functional_1/embed_plat/embedding_lookup/Identity */ /* ty=Tensor[(1, 1, 8), float32] */;
```
However, after the TIR optimizations, hash_table_create is gone from the PrimFuncs:
```
, GlobalVar(tvmgen_default_fused_take_1): PrimFunc([placeholder, placeholder, T_take]) attrs={"from_legacy_te_schedule": (bool)1, "global_symbol": "tvmgen_default_fused_take_1", "tir.noalias": (bool)1, "target": llvm -keys=cpu -libs=dnnl -link-params=0 -mattr=sse4.2,fma,avx512f -mcpu=x86-64 -opt-level=3} {
  parallel (ax0.ax1.fused, 0, 64) {
    T_take[ramp((ax0.ax1.fused*8), 1, 8)] = placeholder[(x8((min(max((int64)0, placeholder[ax0.ax1.fused]), (int64)49999)*(int64)8)) + int64x8(ramp(0, 1, 8)))]
  }
}
, GlobalVar(tvmgen_default_fused_lookup_table_find): PrimFunc([placeholder, placeholder, placeholder, hash_table_find]) attrs={"from_legacy_te_schedule": (bool)1, "global_symbol": "tvmgen_default_fused_lookup_table_find", "tir.noalias": (bool)1, "target": llvm -keys=cpu -libs=dnnl -link-params=0 -mattr=sse4.2,fma,avx512f -mcpu=x86-64 -opt-level=3} {
  // attr [0] extern_scope = 0
  tir.tvm_call_packed("tvm.contrib.hashtable.find", tir.tvm_stack_make_array(placeholder, tir.tvm_stack_make_shape(1), 0, 1, 0, 0), tir.tvm_stack_make_array(placeholder, tir.tvm_stack_make_shape(1, 1), 0, 2, (object64)0, 0), tir.tvm_stack_make_array(placeholder, tir.tvm_stack_make_shape(), 0, 0, (int64)0, 0), tir.tvm_stack_make_array(hash_table_find, tir.tvm_stack_make_shape(1, 1), 0, 2, (int64)0, 0))
}
```
The hash_table_create CallNode is independent: its only inputs are two ConstNodes, and its output is the input of another CallNode, lookup_table_find. What inside the TIR passes eliminates this node?
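One mechanism that would produce this symptom, stated as an assumption rather than a diagnosis, is constant folding before lowering: since hash_table_create's only inputs are constants, a folding pass could evaluate it at compile time (which would also fit seeing its TOPI implementation run) and substitute the result as a constant, leaving no call to lower into a PrimFunc. The toy model below shows that pattern only; `Const`, `Var`, `Call`, `OPS` and `fold_constants` are hypothetical names, not TVM's API, the third argument of the find call is treated as a default value by assumption, and the `stateful` flag is a hypothetical stand-in for any per-op attribute a real pass might consult to skip folding.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple, Union


@dataclass(frozen=True)
class Const:
    value: object  # toy compile-time constant


@dataclass(frozen=True)
class Var:
    name: str  # toy runtime input, e.g. %plat


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


Expr = Union[Const, Var, Call]

# Hypothetical op registry: op name -> (compile-time evaluator, stateful flag).
OPS: Dict[str, Tuple[Callable, bool]] = {
    "hash_table_create": (lambda keys, values: dict(zip(keys, values)), False),
    "lookup_table_find": (lambda table, key, default: table.get(key, default), False),
}


def fold_constants(expr: Expr, ops=OPS) -> Expr:
    """Post-order constant folding: a call whose arguments are all constants is
    evaluated now and replaced by a Const, unless its op is marked stateful."""
    if not isinstance(expr, Call):
        return expr
    args = tuple(fold_constants(a, ops) for a in expr.args)
    evaluate, stateful = ops[expr.op]
    if not stateful and all(isinstance(a, Const) for a in args):
        return Const(evaluate(*(a.value for a in args)))
    return Call(expr.op, args)


# Mirrors %517/%518: create takes two constants, find also takes runtime %plat.
table = Call("hash_table_create", (Const(("a", "b")), Const((1, 2))))
graph = Call("lookup_table_find", (table, Var("plat"), Const(1)))
folded = fold_constants(graph)
# The create call no longer exists; its result is now a Const argument.
print(folded)
```

If something like this happens, the table never reaches TIR as a call, so printing the Relay module after each Relay pass, rather than around TIR passes, would be where the node's disappearance shows up.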
In the thread “[Dev API] Possible to remove AST nodes?”, tianqi answered: “Instead, in order to remove something, we might want to simply return an Evaluate(0) (which is a nop) in post order transform”. Having checked the code, RemoveNoOp looks suspicious, but I haven't yet worked out the whole scenario.
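The pattern tianqi describes can be sketched as a toy: the post-order callback does not delete a statement, it returns `Evaluate(0)` in its place, and a separate cleanup pass later drops the no-ops. The classes and functions below are hypothetical stand-ins, not TVM's `tir` API, and `hypothetical.hashtable.create` is an invented name (only the find call appears in the dump above).

```python
from dataclasses import dataclass
from typing import Callable, Union


@dataclass(frozen=True)
class Evaluate:
    value: object  # Evaluate(0) plays the role of a no-op statement


@dataclass(frozen=True)
class CallPacked:
    name: str  # toy stand-in for a tir.tvm_call_packed statement


@dataclass(frozen=True)
class SeqStmt:
    body: tuple


Stmt = Union[Evaluate, CallPacked, SeqStmt]


def post_order_transform(stmt: Stmt, fn: Callable[[Stmt], Stmt]) -> Stmt:
    """Rewrite children first, then hand the rebuilt node to fn."""
    if isinstance(stmt, SeqStmt):
        stmt = SeqStmt(tuple(post_order_transform(s, fn) for s in stmt.body))
    return fn(stmt)


def is_no_op(stmt: Stmt) -> bool:
    # Toy rule: evaluating an integer constant has no side effects.
    return isinstance(stmt, Evaluate) and isinstance(stmt.value, int)


def remove_no_op(stmt: Stmt) -> Stmt:
    """Toy analogue of a RemoveNoOp-style cleanup: drop no-ops from sequences."""
    if not isinstance(stmt, SeqStmt):
        return stmt
    body = tuple(s for s in (remove_no_op(c) for c in stmt.body) if not is_no_op(s))
    if not body:
        return Evaluate(0)
    return body[0] if len(body) == 1 else SeqStmt(body)


prog = SeqStmt((CallPacked("hypothetical.hashtable.create"),
                CallPacked("tvm.contrib.hashtable.find")))


def drop_create(s: Stmt) -> Stmt:
    # Instead of deleting the node, return a no-op from the post-order callback.
    if isinstance(s, CallPacked) and s.name == "hypothetical.hashtable.create":
        return Evaluate(0)
    return s


print(remove_no_op(post_order_transform(prog, drop_create)))
```

In this two-step pattern the rewritten call is still present (as `Evaluate(0)`) between the two steps and only vanishes at the cleanup, so comparing the IR printed just before and just after a RemoveNoOp-style pass would show whether that step is involved.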
I also checked RemoveNoOp, RemoveUnusedFunctions, and some other TIR passes, but none of their printed output contains a hash_table_create PrimFunc.
Could you please shed some light on it? Thanks.