Can TVM support BERT model inference or not?

Hi,

I'm trying to run a BERT model in TVM, but it fails.
My BERT model and test data come from https://github.com/onnx/models/blob/master/text/machine_comprehension/bert-squad/model/bertsquad-12.tar.gz.
Converting the ONNX model to Relay IR with "mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)" works, but compiling it with "lib = relay.build(mod, target=target, params=params)" fails.

The failure log is below; my target is llvm.

    Input shapes: {'unique_ids_raw_output___9:0': (1,), 'segment_ids:0': (1, 256), 'input_mask:0': (1, 256), 'input_ids:0': (1, 256)}
    Importing graph from ONNX to TVM Relay IR ...
    before GraphProto
    after GraphProto
    before g.from_onnx
    after g.from_onnx
    Compiling graph from Relay IR to llvm ...
    Caught an exception
    Traceback (most recent call last):
      99: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      98: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      97: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      96: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      95: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      94: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      93: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      92: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      91: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      90: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      89: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      88: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      87: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      86: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      85: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      84: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      83: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      82: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      81: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      80: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      79: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      78: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      77: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      76: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      75: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      74: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      73: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      72: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      71: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      70: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      69: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      68: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      67: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      66: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      65: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      64: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      63: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      62: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      61: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      60: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      59: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      58: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      57: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      56: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      55: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      54: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      53: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      52: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      51: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      50: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      49: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      48: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      47: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      46: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      45: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      44: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      43: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      42: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      41: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      40: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      39: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      38: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      37: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      36: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      35: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      34: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      33: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      32: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      31: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      30: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      29: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      28: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      27: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      26: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      25: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      24: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      23: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      22: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      21: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      20: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      19: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      18: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      17: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      16: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      15: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      14: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      13: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      12: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      11: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      10: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      9: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      8: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      7: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      6: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      5: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      4: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      3: tvm::relay::StorageAllocaBaseVisitor::GetToken(tvm::RelayExpr const&)
      2: tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
      1: tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
      0: tvm::relay::StorageAllocator::VisitExpr_(tvm::relay::CallNode const*)
      File "/TVM/src/relay/backend/graph_plan_memory.cc", line 317
    TVMError:

    An error occurred during the execution of TVM. For more information, please see: Handle TVM Errors (tvm 0.9.dev114+g278173c18 documentation)

    Check failed: args.size() == 1U (2 vs. 1)
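For reference, the import-and-build flow described in the question can be sketched as follows. This is a minimal sketch, assuming the `onnx` package and a 0.9-era TVM Relay API are installed. The helper names `bert_squad_shape_dict` and `compile_onnx_bert` are hypothetical; the input names and shapes are copied from the log. Wrapping `relay.build` in `PassContext(opt_level=3)` is a common tutorial pattern, not something the original post did.

```python
# Sketch of the ONNX -> Relay -> relay.build flow from the question.
# Assumes `onnx` and a 0.9-era TVM are installed; helper names are hypothetical.


def bert_squad_shape_dict(seq_len=256):
    """Input names/shapes for bertsquad-12, as printed in the log."""
    return {
        "unique_ids_raw_output___9:0": (1,),
        "segment_ids:0": (1, seq_len),
        "input_mask:0": (1, seq_len),
        "input_ids:0": (1, seq_len),
    }


def compile_onnx_bert(model_path, target="llvm"):
    # Heavy dependencies are imported lazily so the shape helper works without TVM.
    import onnx
    import tvm
    from tvm import relay

    onnx_model = onnx.load(model_path)
    mod, params = relay.frontend.from_onnx(onnx_model, bert_squad_shape_dict())
    # Assumption: opt_level=3 is a common tutorial setting; the post called
    # relay.build without an explicit PassContext.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
    return lib
```

Calling `compile_onnx_bert("path/to/model.onnx")` (hypothetical path) exercises the same `relay.build` call that fails in the log, which makes it a compact starting point for a standalone reproduction.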

TVM can support BERT inference, but I've occasionally run into problems importing models; sometimes I've also had issues with data sizes, etc.

The most stable way of running BERT models I've found is to follow the template set out in this tutorial, but skip the sparse transformation.

If you can get a version of your model from HF (Hugging Face) Transformers, that would minimize headaches.
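A minimal sketch of that Hugging Face route, going through TorchScript and `relay.frontend.from_pytorch`. It assumes `torch`, `transformers` and TVM are installed. The model name `bert-base-uncased`, the sequence length of 128 and the helper names `relay_input_infos` / `compile_hf_bert` are illustrative assumptions, not taken from the thread or from any particular tutorial.

```python
# Sketch: Hugging Face BERT -> torch.jit.trace -> Relay -> relay.build.
# Assumes torch, transformers and TVM are installed; helper names are hypothetical.


def relay_input_infos(batch=1, seq_len=128):
    """(name, (shape, dtype)) pairs in the form from_pytorch accepts."""
    shape = (batch, seq_len)
    return [("input_ids", (shape, "int64")), ("attention_mask", (shape, "int64"))]


def compile_hf_bert(model_name="bert-base-uncased", seq_len=128, target="llvm"):
    import torch
    import tvm
    from tvm import relay
    from transformers import BertModel

    # torchscript=True makes the model return tuples, which tracing requires.
    model = BertModel.from_pretrained(model_name, torchscript=True).eval()
    input_ids = torch.randint(0, model.config.vocab_size, (1, seq_len), dtype=torch.int64)
    attention_mask = torch.ones((1, seq_len), dtype=torch.int64)
    with torch.no_grad():
        traced = torch.jit.trace(model, (input_ids, attention_mask))

    # The input names here must match the ones used later with set_input.
    mod, params = relay.frontend.from_pytorch(traced, relay_input_infos(1, seq_len))
    with tvm.transform.PassContext(opt_level=3):
        return relay.build(mod, target=target, params=params)
```

Fixing the batch size and sequence length at trace time keeps every shape static, which tends to be the easiest case for `relay.build`.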

That said, reproducible example code would help identify the source of this issue.
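A reproducible example usually also feeds inputs to the compiled module, so the whole pipeline runs from one script. A minimal sketch of that half, assuming numpy and TVM's `graph_executor` are available. `make_dummy_inputs` and `run_compiled` are hypothetical helpers; treating every input as int64 and using 30522 as the vocabulary size (BERT-base uncased) are assumptions.

```python
import random


def make_dummy_inputs(shape_dict, vocab_size=30522, seed=0):
    """Deterministic dummy BERT inputs as (shape, flat list) per input name."""
    rng = random.Random(seed)
    inputs = {}
    for name, shape in shape_dict.items():
        n = 1
        for dim in shape:
            n *= dim
        if "mask" in name:
            flat = [1] * n  # attend to every position
        elif "segment" in name:
            flat = [0] * n  # single-segment input
        elif "input_ids" in name:
            flat = [rng.randrange(vocab_size) for _ in range(n)]  # random token ids
        else:
            flat = list(range(n))  # e.g. unique ids: 0, 1, ...
        inputs[name] = (shape, flat)
    return inputs


def run_compiled(lib, shape_dict):
    # Lazy imports: only needed when actually running a relay.build result.
    import numpy as np
    import tvm
    from tvm.contrib import graph_executor

    module = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
    for name, (shape, flat) in make_dummy_inputs(shape_dict).items():
        # Assumption: every model input is int64.
        module.set_input(name, np.array(flat, dtype="int64").reshape(shape))
    module.run()
    return [module.get_output(i).numpy() for i in range(module.get_num_outputs())]
```

Seeding the generator keeps the inputs identical between runs, so anyone reproducing the issue sees the same behavior.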

Thanks, Wheest. Is there an example code link I could refer to?

See, for example, my script: https://github.com/masahi/torchscript-to-tvm/blob/master/transformers/test_bert.py