"Cannot allocate memory symbolic tensor shape [?, 1]" error when tf.import_graph_def() is used

I am trying to run a DCN model with TVM. The original model has dynamic-shape inputs (dynamic batch size) and was exported from TensorFlow as a .pb file. My code looks like this:

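```python
import logging

import tensorflow as tf
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)

# TF1/TF2 compatibility alias
try:
    tf_compat_v1 = tf.compat.v1
except ImportError:
    tf_compat_v1 = tf

# model_path, input_shape, wt_val and id_val are defined elsewhere in my script
with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
    graph_def = tf_compat_v1.GraphDef()
    graph_def.ParseFromString(f.read())

    # If I uncomment this line, I get the
    # "Cannot allocate memory symbolic tensor shape [?, 1]" error:
    # tf.import_graph_def(graph_def, name='')

shape_dict = {
    'wt': input_shape,
    'id': input_shape,
}

layout = "NCHW"
mod, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict)
target = tvm.target.cuda()
opt_level = 3
with tvm.transform.PassContext(opt_level=opt_level):
    lib = relay.build(mod, target=target, params=params)

ctx = tvm.gpu(0)  # assumed: the original excerpt does not show how ctx is created
m = graph_runtime.GraphModule(lib["default"](ctx))
# some random test data
m.set_input('wt', wt_val)
m.set_input('id', id_val)
m.run()
out = m.get_output(0)
LOG.info(f'out: {out}')
```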

If I uncomment the tf.import_graph_def() line, I get the "Cannot allocate memory symbolic tensor shape [?, 1]" error. The full error log:

```
DEBUG:autotvm:Finish loading 825 records
Traceback (most recent call last):
  File "dcn_cbg_tvm.py", line 114, in <module>
    main()
  File "dcn_cbg_tvm.py", line 97, in main
    lib = relay.build(mod, target=target, target_host=None, params=params)
  File "/tvm/python/tvm/relay/build_module.py", line 275, in build
    graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
  File "/tvm/python/tvm/relay/build_module.py", line 138, in build
    self._build(mod, target, target_host)
  File "/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /tvm/build/libtvm.so(tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x6f) [0x7f3916653daf]
  [bt] (7) /tvm/build/libtvm.so(tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)+0xc2) [0x7f39165fb9c2]
  [bt] (6) /tvm/build/libtvm.so(tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)+0x8b) [0x7f39166aabcb]
  [bt] (5) /tvm/build/libtvm.so(tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x6f) [0x7f3916653daf]
  [bt] (4) /tvm/build/libtvm.so(tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)+0x1b5) [0x7f39165fbab5]
  [bt] (3) /tvm/build/libtvm.so(tvm::relay::StorageAllocator::CreateToken(tvm::RelayExprNode const*, bool)+0x185) [0x7f39165fb5d5]
  [bt] (2) /tvm/build/libtvm.so(tvm::relay::StorageAllocator::Request(tvm::relay::StorageToken*)+0x34) [0x7f39165fa7a4]
  [bt] (1) /tvm/build/libtvm.so(tvm::relay::StorageAllocator::GetMemorySize(tvm::relay::StorageToken*)+0x296) [0x7f39165fa276]
  [bt] (0) /tvm/build/libtvm.so(+0x19834d8) [0x7f39165f84d8]
  File "/tvm/src/relay/backend/graph_plan_memory.cc", line 292
TVMError:
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------
  Check failed: pval != nullptr == false: Cannot allocate memory symbolic tensor shape [?, 1]
```

If I comment it out, the model builds and runs correctly:

```
DEBUG:autotvm:Finish loading 825 records
INFO:compile_engine:Using injective.cpu for reshape based on highest priority (10)
INFO:compile_engine:Using argwhere.generic for argwhere based on highest priority (10)
...
WARNING:autotvm:Cannot find config for target=cuda -keys=cuda,gpu -max_num_threads=1024 -model=unknown -thread_warp_size=32, workload=('dense_small_batch.cuda', ('TENSOR', (10, 1878), 'float32'), ('TENSOR', (1, 1878), 'float32'), None, 'float32'). A fallback configuration is used, which may bring great performance regression.
INFO:compile_engine:Using dense_small_batch.cuda for nn.dense based on highest priority (10)
INFO:compile_engine:Using injective.cuda for add based on highest priority (10)
...
INFO:compile_engine:Using reduce.cuda for sum based on highest priority (10)
WARNING:autotvm:Cannot find config for target=cuda -keys=cuda,gpu -max_num_threads=1024 -model=unknown -thread_warp_size=32, workload=('dense_small_batch.cuda', ('TENSOR', (10, 1750), 'float32'), ('TENSOR', (1, 1750), 'float32'), None, 'float32'). A fallback configuration is used, which may bring great performance regression.
INFO:compile_engine:Using dense_small_batch.cuda for nn.dense based on highest priority (10)
...
INFO:compile_engine:Using injective.cuda for add based on highest priority (10)
WARNING:autotvm:Cannot find config for target=cuda -keys=cuda,gpu -max_num_threads=1024 -model=unknown -thread_warp_size=32, workload=('dense_small_batch.cuda', ('TENSOR', (10, 256), 'float32'), ('TENSOR', (128, 256), 'float32'), None, 'float32'). A fallback configuration is used, which may bring great performance regression.
INFO:compile_engine:Using dense_small_batch.cuda for nn.dense based on highest priority (10)
...
INFO:compile_engine:Using injective.cuda for nn.relu based on highest priority (10)
INFO:__main__:out: [1.7003021e-01 3.2966433e-03 2.5015650e-04 3.9903298e-06 8.5008228e-03
```
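A minimal sketch of why a shape like [?, 1] is rejected: the failing check in the trace is in graph_plan_memory.cc, where the graph runtime's static memory planner has to turn every tensor shape into a byte count at build time. The toy function below (hypothetical name `static_tensor_bytes`, plain Python, not TVM source) mimics that idea, assuming a symbolic dimension is represented by the string "?". One possible source of such a dimension, not confirmed here, is an op with a data-dependent output shape such as the argwhere in the compile log: its output has shape (number of non-zero elements, input rank), which for a 1-D input is exactly [?, 1]. TVM's Relay VM, rather than the graph runtime, is the executor meant for dynamic shapes.

```python
def static_tensor_bytes(shape, dtype_bits=32):
    """Toy static memory planner (hypothetical, not TVM code).

    The byte size must be known at build time, so every dimension has to be a
    concrete int; a symbolic dimension (modelled here as "?") is rejected.
    """
    size = 1
    for dim in shape:
        if not isinstance(dim, int):
            shape_str = "[" + ", ".join(str(d) for d in shape) + "]"
            raise ValueError(f"Cannot allocate memory symbolic tensor shape {shape_str}")
        size *= dim
    return size * dtype_bits // 8


# A float32 dense input shape taken from the working build's log.
print(static_tensor_bytes((10, 1878)))

# A data-dependent shape cannot be sized ahead of time.
try:
    static_tensor_bytes(("?", 1))
except ValueError as err:
    print(err)
```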


When I benchmark this model, I see a strange performance drop: if I comment out print(f'out: {out}'), the 99th-percentile time cost (max_top99) rises from 0.45 ms to 4.8 ms at batch size 100, and gets even worse at batch sizes 500 and 1000.

Benchmark code (excerpt):

```python
import time

import numpy as np

# m, input_shape, batch_size, test_round and statistics are defined
# elsewhere in my script (not shown).
for i in range(test_round):
    # prepare test data
    rng = np.random.default_rng()
    wt_val = 50 * (rng.standard_normal(size=input_shape, dtype=np.float32) + 1) / 2
    idx_val = rng.integers(1000, size=input_shape, dtype=np.int32)

    t1 = time.time()

    m.set_input('wt', wt_val)
    m.set_input('id', idx_val)
    m.run()
    out = m.get_output(0)

    t2 = time.time()
    time_cost = t2 - t1
    statistics[batch_size].append(time_cost)
    # If I comment out this line, max_top99 rises from 0.45 ms to 4.8 ms!
    print(f'out: {out}')
    print(f'time cost: {time_cost}')

for batch_size in statistics:
    tc_list = sorted(statistics[batch_size], reverse=True)
    tc_avg = sum(tc_list) * 1000 / test_round
    max_top99 = tc_list[int(test_round * 0.01)] * 1000
    max_top95 = tc_list[int(test_round * 0.05)] * 1000
    print(
        f'batch size: {batch_size}, time cost avg: {tc_avg:.4f}ms, 99% in {max_top99}ms, 95% in {max_top95}ms')
```
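A small refactoring sketch for the statistics part (not from the original script): the loop relies on `statistics[batch_size]` already being a list and divides by `test_round`. The hypothetical helper `latency_summary_ms` keeps the same descending-sort indexing for the 99th/95th percentiles but takes the count from the sample list itself, so it stays correct if some batch size collects a different number of rounds. `defaultdict(list)` is one assumed way to create `statistics`, since the excerpt does not show it.

```python
from collections import defaultdict


def latency_summary_ms(samples_s):
    """Summarise per-call latencies given in seconds, returning milliseconds.

    Same indexing as the benchmark loop: sort descending, then take the sample
    1% / 5% of the way in as the 99th / 95th percentile.
    """
    ordered = sorted(samples_s, reverse=True)
    n = len(ordered)
    return {
        "avg": sum(ordered) * 1000 / n,
        "p99": ordered[int(n * 0.01)] * 1000,
        "p95": ordered[int(n * 0.05)] * 1000,
    }


# defaultdict(list) lets statistics[batch_size].append(...) work without
# creating a list for every batch size up front.
statistics = defaultdict(list)
```

Usage would be `latency_summary_ms(statistics[batch_size])` inside the final `for batch_size in statistics:` loop.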

Test results (with the print statement):

```
INFO:__main__:batch size: 1, time cost avg: 1.7734ms, 99% in 0.3965ms, 95% in 0.3185ms
INFO:__main__:batch size: 10, time cost avg: 1.8373ms, 99% in 0.4249ms, 95% in 0.3586ms
INFO:__main__:batch size: 100, time cost avg: 1.9053ms, 99% in 0.4387ms, 95% in 0.3901ms
INFO:__main__:batch size: 500, time cost avg: 2.4915ms, 99% in 0.6819ms, 95% in 0.6292ms
INFO:__main__:batch size: 1000, time cost avg: 2.9930ms, 99% in 0.8507ms, 95% in 0.8247ms
```

After commenting out the print statement:

```
INFO:__main__:batch size: 1, time cost avg: 1.7674ms, 99% in 0.8605ms, 95% in 0.3073ms
INFO:__main__:batch size: 10, time cost avg: 2.0924ms, 99% in 0.8798ms, 95% in 0.6173ms
INFO:__main__:batch size: 100, time cost avg: 5.6550ms, 99% in 4.8485ms, 95% in 4.8087ms
INFO:__main__:batch size: 500, time cost avg: 20.3439ms, 99% in 23.4420ms, 95% in 23.3698ms
INFO:__main__:batch size: 1000, time cost avg: 38.9389ms, 99% in 46.8240ms, 95% in 38.2991ms
```
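Putting the two result logs side by side makes the pattern clearer: without print() the p99 time grows roughly in proportion to batch size (about 4.8, 23.4 and 46.8 ms at batch sizes 100, 500 and 1000), whereas with print() it stays below 1 ms at every batch size. The snippet below only recomputes the slowdown ratio from the numbers reported above; it involves no new measurements.

```python
# p99 latencies (ms) copied from the two result logs above.
p99_with_print = {1: 0.3965, 10: 0.4249, 100: 0.4387, 500: 0.6819, 1000: 0.8507}
p99_without_print = {1: 0.8605, 10: 0.8798, 100: 4.8485, 500: 23.4420, 1000: 46.8240}

# Slowdown factor per batch size when the print statement is removed.
ratios = {bs: p99_without_print[bs] / p99_with_print[bs] for bs in p99_with_print}

for bs, r in ratios.items():
    print(f"batch {bs:>4}: {p99_with_print[bs]:.4f} ms -> {p99_without_print[bs]:.4f} ms ({r:.1f}x)")
```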

Can anyone help? I'd really appreciate it.
@tqchen

After further investigation, I found that this print() performance issue is related to asynchronous execution. Similar topics: https://discuss.tvm.apache.org/t/how-could-we-request-a-inference-synchronously/1135 and https://discuss.tvm.apache.org/t/how-to-make-get-output-function-faster/5005/2

When print() is used, the measured time is only the asynchronous kernel-launch time; no result is copied back from the device inside the timed region. If I add `context.sync()`, the measured time matches the case without print().

Why print() causes this still needs a bit more digging.
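One plausible explanation, not confirmed in this thread, can be reproduced with a toy model in plain Python (no TVM involved). It assumes that (1) m.run() only enqueues GPU work and returns, (2) printing `out` copies the result back to the host, which has to wait for the GPU to finish, and (3) the host-to-device copy in the next set_input() also waits for previously queued work. Under those assumptions, with print() the waiting happens after t2, outside the timed region, so only launch overhead is measured; without print() the wait moves into the next iteration's set_input(), inside the timed region; and an explicit sync before t2 gives about the same numbers as the no-print case. `ToyAsyncDevice` and `benchmark` are hypothetical names.

```python
import queue
import threading
import time


class ToyAsyncDevice:
    """Toy model of a GPU stream (not TVM code): launch() only enqueues work,
    which a background thread then executes in submission order."""

    def __init__(self):
        self._jobs = queue.Queue()
        threading.Thread(target=self._worker, daemon=True).start()

    def _worker(self):
        while True:
            seconds = self._jobs.get()
            time.sleep(seconds)  # pretend to execute a kernel
            self._jobs.task_done()

    def launch(self, seconds):
        self._jobs.put(seconds)  # returns immediately, like an async kernel launch

    def sync(self):
        self._jobs.join()  # block until all launched work has finished

    def copy_in(self):
        # ASSUMPTION: the host-to-device input copy waits for earlier queued work.
        self.sync()

    def copy_out(self):
        # Reading the result back to the host (what print(out) triggers) waits too.
        self.sync()


def benchmark(mode, rounds=4, kernel_s=0.05):
    """Mimic the benchmark loop; mode is 'print', 'none' or 'sync'."""
    if mode not in ("print", "none", "sync"):
        raise ValueError(f"unknown mode: {mode}")
    dev = ToyAsyncDevice()
    costs = []
    for _ in range(rounds):
        t1 = time.perf_counter()
        dev.copy_in()         # stands in for m.set_input(...)
        dev.launch(kernel_s)  # stands in for m.run()
        if mode == "sync":
            dev.sync()        # explicit sync inside the timed region
        t2 = time.perf_counter()
        costs.append(t2 - t1)
        if mode == "print":
            dev.copy_out()    # print(out): the wait happens outside the timed region
    dev.sync()
    return costs


for mode in ("print", "none", "sync"):
    print(mode, [f"{c * 1000:.1f} ms" for c in benchmark(mode)])
```

With the default 50 ms fake kernel, the "print" mode reports near-zero times while "none" (after its first round) and "sync" report about 50 ms per round, mirroring the pattern in the benchmark logs. For real measurements, the TVM tutorials typically use the module's time evaluator (for example `m.module.time_evaluator("run", ctx, number=..., repeat=...)`), which handles synchronization itself.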