[Performance] TVM - PyTorch BERT on CPU

I ran some experiments to test whether TVM can help accelerate BERT inference. I used Hugging Face to load bert_base_uncased as my model and tried to follow your tutorials. My environment settings and results are below. Does anyone have any idea why TVM does not deliver a significant improvement?

Environment Settings

  • CPU: m5.2 instance on Amazon EC2
  • Host CPU: skylake-avx512
  • Requirements:
conda env create --file conda/build-environment.yaml (default in TVM) 
conda install --no-update-deps -y -c conda-forge openblas==0.3.12
conda install --no-update-deps -y -c intel mkl-include==2021.2.0
conda install --no-update-deps -y -c intel mkl==2021.2.0
pip install transformers==4.6.1 torch==1.7.1 decorator==5.0.9 attrs==20.2.0 tornado==6.1 xgboost==1.4.2 cloudpickle==1.6.0 psutil==5.8.0
  • config.cmake:
set(USE_LLVM ON)
set(USE_BLAS openblas)
set(USE_MKL ON)
set(USE_MKLDNN /home/chengpi/projects/dnnl_lnx_2.2.0_cpu_iomp)
set(USE_OPENMP intel)
set(USE_NNPACK ON)
set(NNPACK_PATH /home/chengpi/projects/NNPACK)
# other parameters are default

Experiment 1 - Different Target

  • Q1: Why does -libs=cblas give the best results, rather than mkl or mkl-dnn?
  • Q2: Why do I get the following error with -libs=mkl? When I use MXNet instead of PyTorch, the error disappears.
AttributeError: Traceback (most recent call last):
  25: TVMFuncCall
  24: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  23: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&)
  22: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::backend::GraphExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  21: tvm::relay::backend::GraphExecutorCodegen::Codegen(tvm::relay::Function)
  20: tvm::relay::backend::MemoizedExprTranslator<std::vector<tvm::relay::backend::GraphNodeRef, std::allocator<tvm::relay::backend::GraphNodeRef> > >::VisitExpr(tvm::RelayExpr const&)
  19: _ZZN3tvm5relay11ExprFunc
  18: tvm::relay::backend::GraphExecutorCodegen::VisitExpr_(tvm::relay::TupleNode const*)
  17: tvm::relay::backend::MemoizedExprTranslator<std::vector<tvm::relay::backend::GraphNodeRef, std::allocator<tvm::relay::backend::GraphNodeRef> > >::VisitExpr(tvm::RelayExpr const&)
  16: _ZZN3tvm5relay11ExprFunc
  15: tvm::relay::backend::GraphExecutorCodegen::VisitExpr_(tvm::relay::CallNode const*)
  14: tvm::relay::backend::GraphExecutorCodegen::GraphAddCallNode(tvm::relay::CallNode const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, dmlc::any, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, dmlc::any> > >)
  13: tvm::relay::backend::MemoizedExprTranslator<std::vector<tvm::relay::backend::GraphNodeRef, std::allocator<tvm::relay::backend::GraphNodeRef> > >::VisitExpr(tvm::RelayExpr const&)
  12: _ZZN3tvm5relay11ExprFunc
  11: tvm::relay::backend::GraphExecutorCodegen::VisitExpr_(tvm::relay::CallNode const*)
  10: tvm::relay::backend::GraphExecutorCodegen::GraphAddCallNode(tvm::relay::CallNode const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, dmlc::any, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, dmlc::any> > >)
  9: tvm::relay::backend::MemoizedExprTranslator<std::vector<tvm::relay::backend::GraphNodeRef, std::allocator<tvm::relay::backend::GraphNodeRef> > >::VisitExpr(tvm::RelayExpr const&)
  8: _ZZN3tvm5relay11ExprFunc
  7: tvm::relay::backend::GraphExecutorCodegen::VisitExpr_(tvm::relay::CallNode const*)
  6: tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::operator()<tvm::relay::CompileEngine&, tvm::relay::CCacheKey&>(tvm::relay::CompileEngine&, tvm::relay::CCacheKey&) const
  5: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), void tvm::runtime::TypedPackedFunc<tvm::relay::CachedFunc (tvm::relay::CompileEngine, tvm::relay::CCacheKey)>::AssignTypedLambda<tvm::relay::{lambda(tvm::relay::CompileEngine, tvm::relay::CCacheKey)#9}>(tvm::relay::{lambda(tvm::relay::CompileEngine, tvm::relay::CCacheKey)#9}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  4: tvm::relay::CompileEngineImpl::LowerInternal(tvm::relay::CCacheKey const&)
  3: tvm::relay::CreateSchedule(tvm::relay::Function const&, tvm::Target const&)
  2: tvm::relay::ScheduleGetter::Create(tvm::relay::Function const&)
  1: tvm::relay::OpImplementation::Schedule(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Target const&)
  0: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  File "/home/chengpi/projects/tvm/python/tvm/runtime/object.py", line 63, in __getattr__
    return _ffi_node_api.NodeGetAttr(self, name)
  File "/home/chengpi/projects/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  3: TVMFuncCall
  2: _ZNSt17_Function_handlerI
  1: tvm::NodeGetAttr(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::ReflectionVTable::GetAttr(tvm::runtime::Object*, tvm::runtime::String const&) const
  File "../src/node/reflection.cc", line 110
  File "/home/chengpi/projects/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/chengpi/projects/tvm/python/tvm/relay/op/strategy/generic.py", line 51, in wrapper
    return topi_schedule(outs)
  File "/home/chengpi/projects/tvm/python/tvm/autotvm/task/topi_integration.py", line 235, in wrapper
    return topi_schedule(cfg, outs, *args, **kwargs)
  File "/home/chengpi/projects/tvm/python/tvm/topi/x86/dense.py", line 338, in schedule_dense_mkl
    schedule_injective_from_existing(s, out)
  File "/home/chengpi/projects/tvm/python/tvm/topi/x86/injective.py", line 39, in schedule_injective_from_existing
    if len(sch[out].op.axis) >= 5:
  File "/home/chengpi/projects/tvm/python/tvm/runtime/object.py", line 65, in __getattr__
    raise AttributeError("%s has no attribute %s" % (str(type(self)), name))
AttributeError: ExternOp object has no attributed axis
During handling of the above exception, another exception occurred:

Experiment 2 - Different Tuning

  • Q1: Why does tuning seem to have no effect with the target llvm -libs=cblas? With the plain llvm target, the auto-scheduler gives the best performance.

Experiment 3 - Different Length

  • Q1: When the sequence length is short, tuning gives a small improvement and can even beat the original PyTorch results. However, when the length is long, the performance is worse than the original PyTorch. Can anyone explain what is going on?

I was wondering whether I forgot some setup needed for the best performance. I look forward to anyone's help. Thanks a lot!

The first suggestion is that when running on CPUs, you should specify a more precise target to make sure the right intrinsics are used. For example, you should use llvm -mcpu=skylake-avx512 <other-options> to make use of AVX512 on Skylake CPUs.
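
As a rough sketch (assuming `mod` and `params` come from relay.frontend.from_pytorch, and picking -libs=cblas as the extra option), the target string is passed straight to relay.build:

import tvm
from tvm import relay
from tvm.contrib import graph_executor

# `mod` and `params` are assumed to come from relay.frontend.from_pytorch;
# the library option here is just one possible choice.
target = "llvm -mcpu=skylake-avx512 -libs=cblas"

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

dev = tvm.cpu(0)
runtime = graph_executor.GraphModule(lib["default"](dev))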

  • Q1: Why does -libs=cblas give the best results, rather than mkl or mkl-dnn?

This needs detailed profiling, so I don't know the root cause either.

  • Q2: Why do I get the following error with -libs=mkl? When I use MXNet instead of PyTorch, the error disappears.

Looks like the error is caused by this PR. cc @kevinthesun

Experiment 2 - Different Tuning

How did you tune the model?

In addition, you might want to check out the blog post written by @haichen about running BERT with TVM on CPU.

https://medium.com/apache-mxnet/speed-up-your-bert-inference-by-3x-on-cpus-using-apache-tvm-9cf7776cd7f8

@comaniac Thanks for your reply!

Experiment 1 - Different Target

Sorry, I didn't spell this out completely in the chart. I used llvm -mcpu=skylake-avx512 for all the TVM experiments, and I followed the blog post you provided; the results are shown below. I am wondering whether MXNet is simply not as fast as PyTorch to begin with, which is why it is easy to get a large improvement (2x-3x) over it when using TVM.

Experiment 2 - Different Tuning

I tuned my model following this repo, but my model is in PyTorch rather than MXNet. Also, I found a repo that had run experiments on PyTorch BERT before, showing only a small improvement (5%-10%). As a result, I think maybe we can only get a small speedup with PyTorch BERT, or I am missing something needed to accelerate it.
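
For reference, my auto-scheduler flow roughly follows the sketch below (the log file name and trial budget are placeholders, and `mod`/`params` come from the PyTorch frontend):

import tvm
from tvm import relay, auto_scheduler

target = "llvm -mcpu=skylake-avx512"
log_file = "bert_autoschedule.json"  # placeholder name

# Extract tunable tasks from the imported BERT module and tune them jointly.
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tuner.tune(auto_scheduler.TuningOptions(
    num_measure_trials=20000,  # placeholder budget
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
))

# Build with the best schedules found during tuning.
with auto_scheduler.ApplyHistoryBest(log_file):
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        lib = relay.build(mod, target=target, params=params)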

Experiment 3 - Different Length

However, I find we only get a small improvement, and only when the sequence length is not too long.

I see. Yeah, if you have used the auto-scheduler to tune the model, then, as the results shown in the paper indicate, the improvement should be about 10%-15% on Intel CPUs.

For auto-scheduler/autotvm, did you re-tune the model or did you directly reuse the tuning logs provided in TLCBench?

In TLCBench, I also compared the auto-scheduler vs. “llvm -libs=cblas” and found the auto-scheduler to be similar to or slightly better than “llvm -libs=cblas”.

If you have access to AWS, could you also try to run the benchmark on c5.9xlarge? If you use c5.9xlarge, you can directly reuse the log files in TLCBench and see whether you can reproduce the latency numbers listed here. This is a good sanity check for your setup.

There can also be bugs when converting models from PyTorch. I am not sure.

@comaniac Yes, we can get a small improvement, but I am curious why it stops helping when the sequence length increases.

@merrymercy I ran the experiments following your suggestions; the results are shown below.

Experiment 1 - Different Target

I can reproduce your results with the latest TVM when I reuse your tuning logs (AutoTVM failed due to a version mismatch). However, it seems that the PyTorch model gets a worse speedup than MXNet when using TVM. If I increase the sequence length, I still run into the problem below.

Experiment 3 - Different Length

When I increase the sequence length, the original PyTorch model gets the best performance. Also, MXNet is better than PyTorch when using TVM. Therefore, I tried to plot their IR graphs (after the transforms listed below) to find the problems. However, it is hard to tell which operator causes the problems, so I need your advice. Thanks for helping!

# Replace expensive ops (exp, tanh, erf) with fast approximations.
new_mod = tvm.relay.transform.FastMath()(mod)
new_mod = tvm.relay.transform.EliminateCommonSubexpr()(new_mod)
# Bind the weights as constants so they can be constant-folded.
BindPass = tvm.relay.transform.function_pass(
    lambda fn, new_mod, ctx: tvm.relay.build_module.bind_params_by_name(fn, params),
    opt_level=1,
)
new_mod = BindPass(new_mod)
new_mod = tvm.relay.transform.FoldConstant()(new_mod)
# Combine parallel dense / batch_matmul ops (e.g., the QKV projections) into fewer, larger ops.
new_mod = tvm.relay.transform.CombineParallelBatchMatmul()(new_mod)
new_mod = tvm.relay.transform.CombineParallelDense(to_batch=False)(new_mod)
new_mod = tvm.relay.transform.SimplifyExpr()(new_mod)
new_mod = tvm.relay.transform.FoldConstant()(new_mod)
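
The same pipeline can also be expressed with tvm.transform.Sequential, which makes it easy to print the module before and after to compare the IR graphs; a sketch of that alternative:

seq = tvm.transform.Sequential([
    tvm.relay.transform.FastMath(),
    tvm.relay.transform.EliminateCommonSubexpr(),
    BindPass,  # the parameter-binding function_pass defined above
    tvm.relay.transform.FoldConstant(),
    tvm.relay.transform.CombineParallelBatchMatmul(),
    tvm.relay.transform.CombineParallelDense(to_batch=False),
    tvm.relay.transform.SimplifyExpr(),
    tvm.relay.transform.FoldConstant(),
])
with tvm.transform.PassContext(opt_level=3):
    new_mod = seq(mod)
print(new_mod)  # dump the transformed Relay IR as text for comparison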

From your model graphs (really helpful!), we can see that the BERT implementations in PyTorch and MXNet are different. My first, no-insight guess is that the MXNet implementation is more TVM-friendly. My second guess is that the difference comes from the PyTorch and MXNet frontends, which may have different logic when converting ops. Either way, we need a deeper investigation to find the root cause.

Also cc @masahi, you might be interested in this case.

Could it simply be that PyTorch BERT is just super fast? It uses JIT-generated ASM with FBGEMM. To check whether the PyTorch frontend is not doing a great job, you can convert the model to ONNX and try the ONNX frontend.
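
A rough sketch of that path (only input_ids is exported here for simplicity; the file name, opset, and shapes are placeholders):

import torch
import onnx
from tvm import relay

# `pt_model` is assumed to be the Hugging Face BertModel in eval mode.
batch, seq_len = 1, 128
input_ids = torch.randint(0, 30522, (batch, seq_len), dtype=torch.long)

torch.onnx.export(
    pt_model,
    (input_ids,),
    "bert.onnx",
    input_names=["input_ids"],
    opset_version=11,
)

onnx_model = onnx.load("bert.onnx")
mod, params = relay.frontend.from_onnx(
    onnx_model, shape={"input_ids": (batch, seq_len)}
)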

@masahi I added ONNX to the experiments below, and it seems that ONNX Runtime gets the best performance regardless of the sequence length (without tuning). I use ONNX Runtime with GraphOptimizationLevel.ORT_ENABLE_ALL as shown in this link. Besides, I plotted the IR graph for ONNX, which is quite complicated.
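
For reference, the ONNX Runtime setup is essentially the standard one from that link (a sketch; the model path and input name are placeholders):

import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

session = ort.InferenceSession("bert.onnx", sess_options)
# `input_ids_np` is a numpy int64 array of shape (1, seq_len).
outputs = session.run(None, {"input_ids": input_ids_np})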

Experiment 1 - Different Target

Experiment 3 - Different Length

Also, I have some questions about auto-scheduler tuning.

  • Q1: @merrymercy I am confused: when I use the auto-scheduler to tune with TVM, can I use a target like llvm -libs=cblas, or should I use only llvm? I found that these give different tasks to tune.

  • Q2: @comaniac I think the MXNet IR is more friendly than the PyTorch IR for auto-scheduler tuning. I set the same tuning parameters, but PyTorch cannot reach the results MXNet gets (16ms for seq_len=128). Their tuning tasks are listed below, and they look quite different because of the different IR graphs. I am still working out where the problem comes from, the TVM frontend or the original model implementation. But I think maybe TVM could have some transforms to generate similar IR graphs even when importing from different frameworks.

    • 1st difference: MXNet uses nn.bias_add() while PyTorch uses relay.add(), which causes the corresponding tuning tasks to not include this operation (tasks 0, 1, 2, 6).
    • 2nd difference: Their attention softmax operations have different shapes, but I don't think this causes much latency difference (task 4).
# Tasks for Pytorch AutoSchedule Tuning (Target = llvm)
========== Task 0 (workload key: ["61f56dfd63fda28bc8bcf85739c8e9e3", 128, 3072, 768, 3072, 128, 768]) ==========
placeholder = PLACEHOLDER [128, 3072]
placeholder = PLACEHOLDER [768, 3072]
T_dense(i, j) += (placeholder[i, k]*placeholder[j, k])

========== Task 1 (workload key: ["61f56dfd63fda28bc8bcf85739c8e9e3", 128, 768, 3072, 768, 128, 3072]) ==========
placeholder = PLACEHOLDER [128, 768]
placeholder = PLACEHOLDER [3072, 768]
T_dense(i, j) += (placeholder[i, k]*placeholder[j, k])

========== Task 2 (workload key: ["61f56dfd63fda28bc8bcf85739c8e9e3", 128, 768, 768, 768, 128, 768]) ==========
placeholder = PLACEHOLDER [128, 768]
placeholder = PLACEHOLDER [768, 768]
T_dense(i, j) += (placeholder[i, k]*placeholder[j, k])

========== Task 3 (workload key: ["d2a28fdf41e83222456f5a6e5bf8a24a", 12, 128, 128, 12, 64, 128, 12, 128, 64]) ==========
placeholder = PLACEHOLDER [12, 128, 128]
placeholder = PLACEHOLDER [12, 64, 128]
compute(b, i, j) += (placeholder[b, i, k]*placeholder[b, j, k])

========== Task 4 (workload key: ["868c2771b1610bdac0ac73167691f4eb", 1, 12, 128, 128, 1, 12, 128, 128]) ==========
placeholder = PLACEHOLDER [1, 12, 128, 128]
T_softmax_maxelem(i0, i1, i2) max= placeholder[i0, i1, i2, k]
T_softmax_delta(i0, i1, i2, i3) = (placeholder[i0, i1, i2, i3] - T_softmax_maxelem[i0, i1, i2])
T_fast_exp(ax0, ax1, ax2, ax3) = max((tir.reinterpret(tir.shift_left(int32((tir.floor(((max(min(T_softmax_delta[ax0, ax1, ax2, a ..(OMITTED).. max_delta[ax0, ax1, ax2, ax3], 88.3763f), -88.3763f)*1.4427f) + 0.5f))*0.693147f))) + 1f)), T_softmax_delta[ax0, ax1, ax2, ax3])
T_softmax_expsum(i0, i1, i2) += T_fast_exp[i0, i1, i2, k]
T_softmax_norm(i0, i1, i2, i3) = (T_fast_exp[i0, i1, i2, i3]/T_softmax_expsum[i0, i1, i2])

========== Task 5 (workload key: ["d2a28fdf41e83222456f5a6e5bf8a24a", 12, 128, 64, 12, 128, 64, 12, 128, 128]) ==========
placeholder = PLACEHOLDER [12, 128, 64]
placeholder = PLACEHOLDER [12, 128, 64]
compute(b, i, j) += (placeholder[b, i, k]*placeholder[b, j, k])

========== Task 6 (workload key: ["61f56dfd63fda28bc8bcf85739c8e9e3", 128, 768, 2304, 768, 128, 2304]) ==========
placeholder = PLACEHOLDER [128, 768]
placeholder = PLACEHOLDER [2304, 768]
T_dense(i, j) += (placeholder[i, k]*placeholder[j, k])

========== Task 7 (workload key: ["2dde9ffcbf97381c0f0307643e09dac5", 1, 128, 768, 1, 128, 1]) ==========
placeholder = PLACEHOLDER [1, 128, 768]
placeholder_red(ax0, ax1, ax2) += placeholder[ax0, ax1, k2]
T_divide(ax0, ax1, ax2) = (placeholder_red[ax0, ax1, ax2]/768f)

========== Task 8 (workload key: ["dde89265d3f1a59075cee648386eac1e", 1, 128, 768, 1, 128, 1, 1, 128, 1]) ==========
placeholder = PLACEHOLDER [1, 128, 768]
placeholder = PLACEHOLDER [1, 128, 1]
T_subtract(ax0, ax1, ax2) = (placeholder[ax0, ax1, ax2] - placeholder[ax0, ax1, 0])
T_multiply(ax0, ax1, ax2) = (T_subtract[ax0, ax1, ax2]*T_subtract[ax0, ax1, ax2])
T_multiply_red(ax0, ax1, ax2) += T_multiply[ax0, ax1, k2]
T_divide(ax0, ax1, ax2) = (T_multiply_red[ax0, ax1, ax2]/768f)

========== Task 9 (workload key: ["9e3bd222d4f8d250aeadf2fef0b15f2b", 1, 768, 768, 768, 768, 1, 768]) ==========
placeholder = PLACEHOLDER [1, 768]
placeholder = PLACEHOLDER [768, 768]
T_dense(i, j) += (placeholder[i, k]*placeholder[j, k])
placeholder = PLACEHOLDER [768]
T_add(ax0, ax1) = (T_dense[ax0, ax1] + placeholder[ax1])
T_minimum(ax0, ax1) = min(T_add[ax0, ax1], 9f)
T_maximum(ax0, ax1) = max(T_minimum[ax0, ax1], -9f)
T_fast_tanh(ax0, ax1) = ((T_maximum[ax0, ax1]*(((T_maximum[ax0, ax1]*T_maximum[ax0, ax1])*(((T_maximum[ax0, ax1]*T_maximum[ax0,  ..(OMITTED).. *T_maximum[ax0, ax1])*(((T_maximum[ax0, ax1]*T_maximum[ax0, ax1])*1.19826e-06f) + 0.000118535f)) + 0.00226843f)) + 0.00489353f))
# Tasks for MXNet AutoSchedule Tuning (Target = llvm)
========== Task 0  (workload key: ["9847f8cc0b305137f49f2c5c0c8ab25d", 128, 3072, 768, 3072, 768, 128, 768]) ==========
placeholder = PLACEHOLDER [128, 3072]
placeholder = PLACEHOLDER [768, 3072]
T_dense(i, j) += (placeholder[i, k]*placeholder[j, k])
placeholder = PLACEHOLDER [768]
T_add(ax0, ax1) = (T_dense[ax0, ax1] + placeholder[ax1])

========== Task 1  (workload key: ["9847f8cc0b305137f49f2c5c0c8ab25d", 128, 768, 3072, 768, 3072, 128, 3072]) ==========
placeholder = PLACEHOLDER [128, 768]
placeholder = PLACEHOLDER [3072, 768]
T_dense(i, j) += (placeholder[i, k]*placeholder[j, k])
placeholder = PLACEHOLDER [3072]
T_add(ax0, ax1) = (T_dense[ax0, ax1] + placeholder[ax1])

========== Task 2  (workload key: ["9847f8cc0b305137f49f2c5c0c8ab25d", 128, 768, 768, 768, 768, 128, 768]) ==========
placeholder = PLACEHOLDER [128, 768]
placeholder = PLACEHOLDER [768, 768]
T_dense(i, j) += (placeholder[i, k]*placeholder[j, k])
placeholder = PLACEHOLDER [768]
T_add(ax0, ax1) = (T_dense[ax0, ax1] + placeholder[ax1])

========== Task 3  (workload key: ["d2a28fdf41e83222456f5a6e5bf8a24a", 12, 128, 128, 12, 64, 128, 12, 128, 64]) ==========
placeholder = PLACEHOLDER [12, 128, 128]
placeholder = PLACEHOLDER [12, 64, 128]
compute(b, i, j) += (placeholder[b, i, k]*placeholder[b, j, k])

========== Task 4  (workload key: ["4b5e216f8244b4e8f7b6543c4a9087e5", 1536, 128, 1536, 128]) ==========
placeholder = PLACEHOLDER [1536, 128]
T_softmax_maxelem(i0) max= placeholder[i0, k]
T_softmax_delta(i0, i1) = (placeholder[i0, i1] - T_softmax_maxelem[i0])
T_fast_exp(ax0, ax1) = max((tir.reinterpret(tir.shift_left(int32((tir.floor(((max(min(T_softmax_delta[ax0, ax1], 88.3763f), -88. ..(OMITTED).. oor(((max(min(T_softmax_delta[ax0, ax1], 88.3763f), -88.3763f)*1.4427f) + 0.5f))*0.693147f))) + 1f)), T_softmax_delta[ax0, ax1])
T_softmax_expsum(i0) += T_fast_exp[i0, k]
T_softmax_norm(i0, i1) = (T_fast_exp[i0, i1]/T_softmax_expsum[i0])

========== Task 5  (workload key: ["d2a28fdf41e83222456f5a6e5bf8a24a", 12, 128, 64, 12, 128, 64, 12, 128, 128]) ==========
placeholder = PLACEHOLDER [12, 128, 64]
placeholder = PLACEHOLDER [12, 128, 64]
compute(b, i, j) += (placeholder[b, i, k]*placeholder[b, j, k])

========== Task 6  (workload key: ["9847f8cc0b305137f49f2c5c0c8ab25d", 128, 768, 2304, 768, 2304, 128, 2304]) ==========
placeholder = PLACEHOLDER [128, 768]
placeholder = PLACEHOLDER [2304, 768]
T_dense(i, j) += (placeholder[i, k]*placeholder[j, k])
placeholder = PLACEHOLDER [2304]
T_add(ax0, ax1) = (T_dense[ax0, ax1] + placeholder[ax1])

========== Task 7  (workload key: ["2dde9ffcbf97381c0f0307643e09dac5", 128, 1, 768, 128, 1, 1]) ==========
placeholder = PLACEHOLDER [128, 1, 768]
placeholder_red(ax0, ax1, ax2) += placeholder[ax0, ax1, k2]
T_divide(ax0, ax1, ax2) = (placeholder_red[ax0, ax1, ax2]/768f)

========== Task 8  (workload key: ["dde89265d3f1a59075cee648386eac1e", 128, 1, 768, 128, 1, 1, 128, 1, 1]) ==========
placeholder = PLACEHOLDER [128, 1, 768]
placeholder = PLACEHOLDER [128, 1, 1]
T_subtract(ax0, ax1, ax2) = (placeholder[ax0, ax1, ax2] - placeholder[ax0, ax1, 0])
T_multiply(ax0, ax1, ax2) = (T_subtract[ax0, ax1, ax2]*T_subtract[ax0, ax1, ax2])
T_multiply_red(ax0, ax1, ax2) += T_multiply[ax0, ax1, k2]
T_divide(ax0, ax1, ax2) = (T_multiply_red[ax0, ax1, ax2]/768f)

========== Task 9  (workload key: ["9e3bd222d4f8d250aeadf2fef0b15f2b", 1, 768, 768, 768, 768, 1, 768]) ==========
placeholder = PLACEHOLDER [1, 768]
placeholder = PLACEHOLDER [768, 768]
T_dense(i, j) += (placeholder[i, k]*placeholder[j, k])
placeholder = PLACEHOLDER [768]
T_add(ax0, ax1) = (T_dense[ax0, ax1] + placeholder[ax1])
T_minimum(ax0, ax1) = min(T_add[ax0, ax1], 9f)
T_maximum(ax0, ax1) = max(T_minimum[ax0, ax1], -9f)
T_fast_tanh(ax0, ax1) = ((T_maximum[ax0, ax1]*(((T_maximum[ax0, ax1]*T_maximum[ax0, ax1])*(((T_maximum[ax0, ax1]*T_maximum[ax0,  ..(OMITTED).. *T_maximum[ax0, ax1])*(((T_maximum[ax0, ax1]*T_maximum[ax0, ax1])*1.19826e-06f) + 0.000118535f)) + 0.00226843f)) + 0.00489353f))

Sorry for all my questions. I'll do my best to run more experiments and figure out why the PyTorch auto-scheduler results don't match MXNet and why TVM does not work as expected when the sequence length increases.

Thanks for the detailed information.

For Q1, when you extract tasks with llvm -mcpu=skylake-avx512 -libs=cblas, some operators (e.g., dense) will be offloaded to cblas. That means those operators won't be compiled by the TVM codegen, so the auto-scheduler won't see or tune them.
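
You can see this directly by extracting tasks under both targets and comparing the counts (a sketch; `mod` and `params` are your imported BERT module):

from tvm import auto_scheduler

for target in [
    "llvm -mcpu=skylake-avx512",
    "llvm -mcpu=skylake-avx512 -libs=cblas",
]:
    tasks, weights = auto_scheduler.extract_tasks(mod["main"], params, target)
    # With -libs=cblas, dense is offloaded to the external library,
    # so it no longer appears as a tunable task here.
    print(target, "->", len(tasks), "tasks")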

For Q2, the two differences you pointed out don't seem very impactful. Maybe you can use the debugger to compare the latency breakdown between the two models: Debugger — tvm 0.8.dev0 documentation
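
A minimal sketch of the debug executor flow from those docs (assuming `graph_json`, `lib`, and `params` come from the build step, and "input_ids" as a placeholder input name):

import tvm
from tvm.contrib.debugger import debug_executor

dev = tvm.cpu(0)
m = debug_executor.create(graph_json, lib, dev, dump_root="/tmp/tvmdbg")
m.set_input(**params)
m.set_input("input_ids", tvm.nd.array(input_ids_np))
m.run()  # prints a per-operator latency breakdown and writes traces under dump_root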

@comaniac I followed your instructions to use the debugger and compared the latencies in the IR graphs below (operators with latency > 100us are colored orange). I find that some operations seem hard to tune, such as fused_nn_contrib_dense_pack. All my experiments are done with opt_level=3, required_pass=["FastMath"].

Experiment 1 - Compare BERT on MXNet (68.4ms) and PyTorch (122.9ms)

  • MXNet Debug IR Graph: drive link
  • PyTorch Debug IR Graph: drive link
  • From the above graphs, we can see MXNet uses the fused_nn_contrib_dense_pack_add operation while PyTorch uses fused_nn_contrib_dense_pack. This happens for all FC layers (4 in one transformer block); I use the latency of the first block as an example.
    • FC for Query, Key, and Value (M: 503us, P: 1314us)
    • FC after self-attention (M: 202us, P: 651us)
    • FC after layer normalization (M: 798us, P: 2578us)
    • FC after GELU (M: 786us, P: 2572us)

Experiment 2 - Compare BERT on MXNet (68.4ms) and MXNet-tune (autoschedule) (15.4ms)

  • MXNet Debug IR Graph: drive link
  • MXNet-tune Debug IR Graph: drive link
  • From the above graphs, it is easy to see that tuning reduces the latency of most dense and batch_matmul operations. I take the first transformer block as an example.
    • FC for Query, Key, and Value (M: 503us, M-tune: 229us)
    • FC after self-attention (M: 202us, M-tune: 99us)
    • FC after layer normalization (M: 798us, M-tune: 292us)
    • FC after GELU (M: 786us, M-tune: 367us)
    • batch_matmul for Query and Key (M: 1828us, M-tune: 41us)
    • batch_matmul for Attention and Value (M: 1312us, M-tune: 29us)

Experiment 3 - Compare BERT on PyTorch (122.9ms) and PyTorch-tune (autoschedule) (64.2ms)

  • PyTorch Debug IR Graph: drive link
  • PyTorch-tune Debug IR Graph: drive link
  • From the above graphs, we can see that tuning only reduces the latency of batch_matmul, not of the dense operations. I take the first transformer block as an example.
    • FC for Query, Key, and Value (P: 1314us, P-tune: 1276us)
    • FC after self-attention (P: 651us, P-tune: 403us)
    • FC after layer normalization (P: 2578us, P-tune: 1779us)
    • FC after GELU (P: 2572us, P-tune: 1684us)
    • batch_matmul for Query and Key (P: 1624us, P-tune: 78us)
    • batch_matmul for Attention and Value (P: 1265us, P-tune: 70us)

Therefore, I suspect the problem comes from the fused_nn_contrib_dense_pack operation, but I don't know why it is slower than fused_nn_contrib_dense_pack_add, even though the latter has an additional add operation for the bias. If you want the complete debug files (json and logs), I provide them in this drive link.

@comaniac I changed the TVM version to commit id 91e07e1f3a7 (Feb. 5, 2021), which is the same as this repo, and the problem is solved: fused_nn_batch_matmul is now used for all FC (dense) layers rather than fused_nn_contrib_dense_pack. I think the problem comes from the PR you provided, which is also why I cannot use -libs=mkl. I include the debug IR graphs below, and now PyTorch speeds up from the 26.63ms PyTorch-script baseline to 17.36ms after tuning with the auto-scheduler.

Hmm, I don't think that PR is the root cause of batch_matmul being used in the BERT model, though. It might be due to this PR, which uses dense instead of batch_matmul when the input shapes are 2D and 3D:

Good work!

There's a known issue that TVM's dense and batch_matmul ops, which compute Y = X * W^T, have poor performance in some models.

There are several matmul & batch_matmul ops in BERT that take data tensors as both input and weight (e.g., those in multi-head attention) rather than using a constant parameter as the weight. In that situation, we see explicit transposes inserted when the model is imported from TensorFlow or PyTorch (they use X * W for matmul by default). MXNet, as far as I know, uses X * W^T by default.
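
As a tiny illustration of the layout point (not taken from the actual frontend code): nn.dense in Relay computes Y = X * W^T, so a framework-level X @ W gets imported as dense plus an explicit transpose, roughly like this:

import tvm
from tvm import relay

x = relay.var("x", shape=(128, 768), dtype="float32")
w = relay.var("w", shape=(768, 3072), dtype="float32")
# dense(x, w_t) computes x * w_t^T, so transposing w first recovers x @ w.
y = relay.nn.dense(x, relay.transpose(w, axes=[1, 0]))
print(tvm.IRModule.from_expr(relay.Function([x, w], y)))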

The PR you found looks like it creates a special schedule for dense + transpose. I'm not sure that is the key to the performance improvement you got, because it is written for AutoTVM, and the auto-scheduler never uses these manual schedules. You could do a more detailed analysis of those dense/batch_matmul ops' layouts and shapes.

I would agree with @comaniac that the mis-conversion from dense to batch_matmul caused some wasted computation before.

Will the auto-scheduler work in the case of "llvm -libs=cblas"? E.g., a network may contain matmul ops (leveraging cblas) and other ops (leveraging the auto-scheduler for performance)?