Running the Compile TensorFlow Models tutorial on GPU

Looking at the Compile TensorFlow Models tutorial, we can take a TensorFlow model and compile and run it using TVM. The tutorial executes the model on the CPU.

When you try to run the model on the GPU, you get the following error: ValueError: not support this layout NHWC yet

I understand that this is because the model is in NHWC format, which is only supported by the x86 backend. The model needs to be in NCHW format to run on the GPU.
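
For reference, the two layouts differ only in axis order, so the data itself just needs a transpose; a minimal numpy sketch with illustrative shapes:

import numpy as np

# One 224x224 RGB image in NHWC (batch, height, width, channels).
x_nhwc = np.zeros((1, 224, 224, 3), dtype="float32")

# Reorder the axes to NCHW (batch, channels, height, width).
x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))
print(x_nchw.shape)  # (1, 3, 224, 224)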

Is there an easy way to take this tutorial code or model and convert it so we can run the TensorFlow model on a GPU using TVM?

I have seen a similar query raised earlier too.
One option I was thinking of is handling the NHWC to NCHW conversion across the model at the frontend.

I could put a patch together for this over the weekend.

That would be great if you could do that. Thanks!

Would it change nnvm.frontend.from_tensorflow() to do the conversion, similar to nnvm.frontend.from_keras()?

Yes, I will handle the conversion inside from_tensorflow with an additional arg, format='NCHW'.

@BenTaylor3115

I made an attempt to translate the entire model to NCHW at the frontend, down to the level of the input itself being in NCHW.

InceptionV1 worked fine, but when I later tried V3 and MobileNet I observed the following:

  • It’s difficult to handle intermediate operations which don’t care about layout, i.e. adjusting the operator attributes for operators like reshape, squeeze, etc.
  • Hence the only option I could see is converting only the layout-sensitive operators (like convolution, pool, etc.) instead of the entire model. This adds a transpose operation before and after each convolution, which is additional overhead, but it works well for any model (a rough sketch follows below).
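
For illustration, the per-operator approach would look roughly like this in nnvm symbols (a hand-written sketch with made-up shapes, not the actual frontend code):

import nnvm.symbol as sym

# NHWC input as it comes from the TensorFlow graph definition.
data = sym.Variable("data")  # e.g. shape (1, 35, 35, 96) in NHWC

# Transpose to NCHW, run the layout-sensitive operator, transpose back.
data_nchw = sym.transpose(data, axes=(0, 3, 1, 2))
conv = sym.conv2d(data_nchw, channels=96, kernel_size=(3, 3),
                  strides=(2, 2), use_bias=False,
                  layout="NCHW", kernel_layout="OIHW")
out = sym.transpose(conv, axes=(0, 2, 3, 1))  # back to NHWC for later ops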

I will try to share a patch soon for you to evaluate.

Thanks! Let me know when you have a patch, and I’ll take a look!

Hi! I also did some work in this direction. I’ve tried to export https://github.com/aymericdamien/TensorFlow-Examples from TF to TVM. Some models worked, but most had some operations missing. Here are the details (may be slightly out of date):

Model                   | Missing operations
Fully Connected         | (all present)
Autoencoder             | (all present)
Convolutional Network   | Pack (https://github.com/dmlc/tvm/pull/1570)
Variational Autoencoder | Exp, RandomStandardNormal, RealDiv, Tanh
RNN                     | Fill, Split, Tanh, Unpack
Dynamic RNN             | Select, LessEqual, Max, Unpack, Split, Greater, Pack, Range, Switch, Min, Transpose, Merge, Tanh
Bidirectional RNN       | Split, Tanh, Unpack
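
The gaps above can be found mechanically by checking each op in the frozen graph against the frontend’s converter table. A rough sketch, assuming the internal _convert_map dict of nnvm.frontend.tensorflow (an implementation detail that may change between versions):

import tensorflow as tf
# _convert_map is an internal table in the NNVM TensorFlow frontend;
# this is only a quick diagnostic, not a supported API.
from nnvm.frontend.tensorflow import _convert_map

# Load a frozen TensorFlow graph.
graph_def = tf.GraphDef()
with open("frozen_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# Op types with no converter; note a few ops (Placeholder, Const) are
# handled specially and may show up here even though they work.
missing = sorted({node.op for node in graph_def.node
                  if node.op not in _convert_map})
print(missing)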

That’s good info about the gaps in the TensorFlow frontend. Thanks.

@BenTaylor3115

Please try this patch. It includes layout conversion for convolution and pooling.

Please share some benchmarking details if possible.

Hi Siva,
Is this patch included in the latest tvm code, or do I need to clone your forked repo to try it out? I’m very interested in trying this. Thanks

Hi Siva,

I changed tvm/nnvm/python/nnvm/frontend/tensorflow.py according to your patch, uninstalled topi, nnvm, and tvm, and then reinstalled all three from the new code. I also changed the tutorial code to:

sym, params = nnvm.frontend.from_tensorflow(graph_def, layout="NCHW")

It runs fine on CPU, but when running on GPU I get the following error:

Traceback (most recent call last):
  File "from_tensorflow.py", line 122, in <module>
    graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype=dtype_dict, params=params)
  File "/usr/local/lib/python2.7/dist-packages/nnvm-0.8.0-py2.7.egg/nnvm/compiler/build_module.py", line 305, in build
    graph = graph.apply("GraphCompile")
  File "/usr/local/lib/python2.7/dist-packages/nnvm-0.8.0-py2.7.egg/nnvm/graph.py", line 234, in apply
    check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
  File "/usr/local/lib/python2.7/dist-packages/nnvm-0.8.0-py2.7.egg/nnvm/_base.py", line 75, in check_call
    raise NNVMError(py_str(_LIB.NNGetLastError()))
nnvm._base.NNVMError: TVMCall CFunc Error:
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/dist-packages/tvm-0.5.dev0-py2.7-linux-x86_64.egg/tvm/_ffi/_ctypes/function.py", line 55, in cfun
    rv = local_pyfunc(*pyargs)
  File "/usr/local/lib/python2.7/dist-packages/nnvm-0.8.0-py2.7.egg/nnvm/compiler/build_module.py", line 115, in _lower
    raise RuntimeError(msg)
RuntimeError: Traceback (most recent call last):
  File "/usr/local/lib/python2.7/dist-packages/nnvm-0.8.0-py2.7.egg/nnvm/compiler/build_module.py", line 107, in _lower
    f = tvm.lower(sch, inputs, name=func_name)
  File "/usr/local/lib/python2.7/dist-packages/tvm-0.5.dev0-py2.7-linux-x86_64.egg/tvm/build_module.py", line 340, in lower
    bounds = schedule.InferBound(sch)
  File "/usr/local/lib/python2.7/dist-packages/tvm-0.5.dev0-py2.7-linux-x86_64.egg/tvm/_ffi/_ctypes/function.py", line 185, in __call__
    ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
  File "/usr/local/lib/python2.7/dist-packages/tvm-0.5.dev0-py2.7-linux-x86_64.egg/tvm/_ffi/base.py", line 66, in check_call
    raise TVMError(py_str(_LIB.TVMGetLastError()))
TVMError: [20:01:54] /tvm/src/schedule/message_passing.cc:36: Check failed: match iter_var(threadIdx.x, Range(min=0, extent=8), threadIdx.x) domain already inferred, cannot prove their extents are the same 7 vs 8

Stack trace returned 10 entries:
[bt] (0) /usr/local/lib/python2.7/dist-packages/tvm-0.5.dev0-py2.7-linux-x86_64.egg/tvm/libtvm.so(dmlc::StackTrace[abi:cxx11]()+0x5b) [0x7faa1f6ea8db]
[bt] (1) /usr/local/lib/python2.7/dist-packages/tvm-0.5.dev0-py2.7-linux-x86_64.egg/tvm/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x28) [0x7faa1f6eb128]
[bt] (2) /usr/local/lib/python2.7/dist-packages/tvm-0.5.dev0-py2.7-linux-x86_64.egg/tvm/libtvm.so(tvm::schedule::Update(std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > >*, tvm::IterVar const&, tvm::Range)+0x330) [0x7faa1f9160c0]
[bt] (3) /usr/local/lib/python2.7/dist-packages/tvm-0.5.dev0-py2.7-linux-x86_64.egg/tvm/libtvm.so(tvm::schedule::PassDownDomain(tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > >*, bool)+0x47a) [0x7faa1f91662a]
[bt] (4) /usr/local/lib/python2.7/dist-packages/tvm-0.5.dev0-py2.7-linux-x86_64.egg/tvm/libtvm.so(tvm::schedule::InferBound(tvm::Schedule const&)+0xed8) [0x7faa1f942c78]
[bt] (5) /usr/local/lib/python2.7/dist-packages/tvm-0.5.dev0-py2.7-linux-x86_64.egg/tvm/libtvm.so(+0x813cf8) [0x7faa1f70fcf8]
[bt] (6) /usr/local/lib/python2.7/dist-packages/tvm-0.5.dev0-py2.7-linux-x86_64.egg/tvm/libtvm.so(TVMFuncCall+0x5e) [0x7faa1fafb00e]
[bt] (7) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7faa605f3e40]
[bt] (8) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x2eb) [0x7faa605f38ab]
[bt] (9) /usr/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(_ctypes_callproc+0x48f) [0x7faa608033df]


Error during compile graph
--------------------------
Graph(%input0, %input1) {
  %input0, shape=[1,96,35,35]
  %input1, shape=[96,96,3,3]
  %2 = conv2d(%input0, %input1, kernel_size='(3, 3)', use_bias='False', channels='96', strides='(2L, 2L)', layout='NCHW', dilation='(1L, 1L)', kernel_layout='OIHW', padding='[0, 0]'), shape=[1,96,17,17]
  ret %2
}
graph_attr_keys = [shape, shape_num_unknown_nodes, dtype, dtype_num_unknown_nodes]

Could you help me?

@jjiang2cal thanks for your interest in this patch.
I will take a look at this issue early next week.

@jjiang2cal

I am a bit confused by the above use case. Why is the layout NCHW in the above graph? With TensorFlow the layout should still be “NHWC”.

The trick here is that passing the extra argument layout="NCHW" internally converts CUDA-incompatible operations (like convolution) to NCHW format.

I have updated the above branch; all TensorFlow frontend test cases now work fine on CUDA.
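
For reference, the intended end-to-end flow on CUDA looks like this (a sketch following the tutorial; graph_def and the preprocessed input x come from the tutorial’s earlier steps):

import nnvm
import nnvm.compiler
import tvm
from tvm.contrib import graph_runtime

# layout="NCHW" makes the frontend convert CUDA-incompatible operators
# (conv2d, pooling) to NCHW internally.
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout="NCHW")

shape_dict = {"DecodeJpeg/contents": x.shape}
graph, lib, params = nnvm.compiler.build(
    sym, target="cuda", shape=shape_dict, params=params)

ctx = tvm.gpu(0)
m = graph_runtime.create(graph, lib, ctx)
m.set_input("DecodeJpeg/contents", tvm.nd.array(x.astype("uint8")))
m.set_input(**params)
m.run()
# (1, 1008) is the InceptionV1 output shape used in the tutorial.
tvm_output = m.get_output(0, tvm.nd.empty((1, 1008), "float32"))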

@srkreddy1238

Hi, last time I replaced the files in the official tvm repo, reinstalled it, and ran the tutorial code.

Today I ran git clone --recursive on your tvm repo, checked out your onnx branch, installed it, and ran your tutorials/nnvm/from_tensorflow.py. I got the following errors:

2018-09-13 18:00:33.770774: W tensorflow/core/framework/op_def_util.cc:346] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
DecodeJpeg: It's a pass through, please handle preprocessing before input
Tensorflow protobuf imported as nnvm graph
[18:00:36] /tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[18:00:36] /tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[18:00:36] /tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[18:00:36] /tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[18:00:36] /tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[18:00:36] /tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[18:00:36] /tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[18:00:36] /tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[18:00:36] /tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[18:00:36] /tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[18:00:36] /tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[18:00:36] /tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
Traceback (most recent call last):
  File "from_tensorflow_layout.py", line 122, in <module>
    graph, lib, params = nnvm.compiler.build(sym, shape=shape_dict, target=target, target_host=target_host, dtype=dtype_dict, params=params)
  File "/root/.local/lib/python2.7/site-packages/nnvm-0.8.0-py2.7.egg/nnvm/compiler/build_module.py", line 304, in build
    graph = graph.apply("GraphCompile")
  File "/root/.local/lib/python2.7/site-packages/nnvm-0.8.0-py2.7.egg/nnvm/graph.py", line 234, in apply
    check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
  File "/root/.local/lib/python2.7/site-packages/nnvm-0.8.0-py2.7.egg/nnvm/_base.py", line 75, in check_call
    raise NNVMError(py_str(_LIB.NNGetLastError()))
nnvm._base.NNVMError: TVMCall CFunc Error:
Traceback (most recent call last):
  File "/root/.local/lib/python2.7/site-packages/tvm-0.4.0-py2.7-linux-x86_64.egg/tvm/_ffi/_ctypes/function.py", line 54, in cfun
    rv = local_pyfunc(*pyargs)
  File "/root/.local/lib/python2.7/site-packages/nnvm-0.8.0-py2.7.egg/nnvm/compiler/build_module.py", line 115, in _lower
    raise RuntimeError(msg)
RuntimeError: Traceback (most recent call last):
  File "/root/.local/lib/python2.7/site-packages/nnvm-0.8.0-py2.7.egg/nnvm/compiler/build_module.py", line 107, in _lower
    f = tvm.lower(sch, inputs, name=func_name)
  File "/root/.local/lib/python2.7/site-packages/tvm-0.4.0-py2.7-linux-x86_64.egg/tvm/build_module.py", line 340, in lower
    bounds = schedule.InferBound(sch)
  File "/root/.local/lib/python2.7/site-packages/tvm-0.4.0-py2.7-linux-x86_64.egg/tvm/_ffi/function.py", line 280, in my_api_func
    return flocal(*args)
  File "/root/.local/lib/python2.7/site-packages/tvm-0.4.0-py2.7-linux-x86_64.egg/tvm/_ffi/_ctypes/function.py", line 184, in __call__
    ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
  File "/root/.local/lib/python2.7/site-packages/tvm-0.4.0-py2.7-linux-x86_64.egg/tvm/_ffi/base.py", line 66, in check_call
    raise TVMError(py_str(_LIB.TVMGetLastError()))
TVMError: [18:00:37] /tvm/src/schedule/message_passing.cc:36: Check failed: match iter_var(threadIdx.x, Range(min=0, extent=8), threadIdx.x) domain already inferred, cannot prove their extents are the same 7 vs 8

Stack trace returned 10 entries:
[bt] (0) /root/.local/lib/python2.7/site-packages/tvm-0.4.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(dmlc::StackTrace[abi:cxx11]()+0x5b) [0x7f691b643acb]
[bt] (1) /root/.local/lib/python2.7/site-packages/tvm-0.4.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x28) [0x7f691b644378]
[bt] (2) /root/.local/lib/python2.7/site-packages/tvm-0.4.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(tvm::schedule::Update(std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > >*, tvm::IterVar const&, tvm::Range)+0x330) [0x7f691b865ed0]
[bt] (3) /root/.local/lib/python2.7/site-packages/tvm-0.4.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(tvm::schedule::PassDownDomain(tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > >*, bool)+0x47a) [0x7f691b86643a]
[bt] (4) /root/.local/lib/python2.7/site-packages/tvm-0.4.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(tvm::schedule::InferBound(tvm::Schedule const&)+0xed8) [0x7f691b892a88]
[bt] (5) /root/.local/lib/python2.7/site-packages/tvm-0.4.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(+0x809558) [0x7f691b663558]
[bt] (6) /root/.local/lib/python2.7/site-packages/tvm-0.4.0-py2.7-linux-x86_64.egg/tvm/libtvm.so(TVMFuncCall+0x5e) [0x7f691ba4ad2e]
[bt] (7) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7f695c53ee40]
[bt] (8) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x2eb) [0x7f695c53e8ab]
[bt] (9) /usr/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(_ctypes_callproc+0x48f) [0x7f695c74e3df]


Error during compile graph
--------------------------
Graph(%input0, %input1) {
  %input0, shape=[1,96,35,35]
  %input1, shape=[96,96,3,3]
  %2 = conv2d(%input0, %input1, kernel_size='(3, 3)', use_bias='False', channels='96', strides='(2L, 2L)', layout='NCHW', dilation='(1L, 1L)', kernel_layout='OIHW', padding='[0, 0]'), shape=[1,96,17,17]
  ret %2
}
graph_attr_keys = [shape, shape_num_unknown_nodes, dtype, dtype_num_unknown_nodes]

My docker image is based on nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04, with Python 2.7 and tensorflow/tensorflow-gpu version 1.10.1.

Should I check out your other branch? What should I do to run your tutorial from_tensorflow.py successfully?

Thanks a lot.

@jjiang2cal

Check this PR, where the TensorFlow models work fine on CUDA.

l2_normalization fails here and I am checking it, but the end-to-end test cases on Inception V1, V3 and MobileNet pass.

@srkreddy1238
The tutorial example with target='cuda' now works for me! Thanks!


Great.

Please share performance results if possible.
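
For anyone benchmarking, a minimal timing sketch with the graph runtime’s time_evaluator, assuming m and ctx are the module and context from the compile step above:

# Time the whole graph: 3 repeats of 10 averaged runs.
ftimer = m.module.time_evaluator("run", ctx, number=10, repeat=3)
prof_res = ftimer()
print("Mean inference time: %.2f ms" % (prof_res.mean * 1000))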