If people could look into this it would be great. My hope is to convert a TensorFlow graph to run under Metal on macOS.
Running the following command:
python ~/dev/tvm_test/from_tensorflow_metal.py
results in the following backtrace:
Traceback (most recent call last):
File "/Users/sam/dev/tvm_test/metal_tf_demo.py", line 75, in <module>
params=params)
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/build_module.py", line 251, in build
graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/build_module.py", line 120, in build
self._build(mod, target, target_host)
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (8) 9 libtvm.dylib 0x00000001159c6e5e tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*) + 14
[bt] (7) 8 libtvm.dylib 0x00000001159c8cf9 tvm::RelayExpr tvm::relay::MixedModeMutator::Rewrite<tvm::relay::CallNode>(tvm::relay::CallNode const*) + 57
[bt] (6) 7 libtvm.dylib 0x00000001159c7c82 tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&) + 2082
[bt] (5) 6 libtvm.dylib 0x00000001159c91fe tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::operator()<tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void>&, tvm::runtime::ObjectRef>(tvm::relay::Call const&&&, tvm::Array<tvm::RelayExpr, void>&&&, tvm::runtime::ObjectRef&&) const + 254
[bt] (4) 5 libtvm.dylib 0x000000011596c9fd std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 109
[bt] (3) 4 libtvm.dylib 0x000000011596ca92 void tvm::runtime::detail::unpack_call_dispatcher<tvm::RelayExpr, 0, 3, tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(tvm::RelayExpr (* const&)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&) + 82
[bt] (2) 3 libtvm.dylib 0x0000000115963b14 tvm::RelayExpr tvm::relay::LayoutRewriter<tvm::relay::alter_op_layout::AlterTransformMemorizer>(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&) + 3364
[bt] (1) 2 libtvm.dylib 0x00000001159675a4 tvm::relay::alter_op_layout::AlterTransformMemorizer::CallWithNewLayouts(tvm::relay::Call const&, std::__1::vector<tvm::RelayExpr, std::__1::allocator<tvm::RelayExpr> > const&) + 1284
[bt] (0) 1 libtvm.dylib 0x0000000115bfe795 std::__1::__function::__func<TVMFuncCreateFromCFunc::$_2, std::__1::allocator<TVMFuncCreateFromCFunc::$_2>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 213
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun
rv = local_pyfunc(*pyargs)
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/op/nn/_nn.py", line 97, in alter_op_layout_conv2d
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
File "<decorator-gen-35>", line 2, in conv2d_alter_layout
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/target/generic_func.py", line 267, in dispatch_func
return dispatch_dict[k](*args, **kwargs)
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/topi-0.7.dev1-py3.7.egg/topi/cuda/conv2d_alter_op.py", line 39, in _alter_conv2d_layout
relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/backend/compile_engine.py", line 183, in select_implementation
all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/backend/compile_engine.py", line 124, in get_valid_implementations
strategy = fstrategy(attrs, inputs, out_type, target)
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/target/generic_func.py", line 45, in __call__
return _ffi_api.GenericFuncCallFunc(self, *args)
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
raise get_last_ffi_error()
[bt] (5) 6 ??? 0x00007ffee6c70d30 0x0 + 140732770225456
[bt] (4) 5 _ctypes.cpython-37m-darwin.so 0x0000000109a5b36f ffi_call_unix64 + 79
[bt] (3) 4 libtvm.dylib 0x0000000115bfcbd6 TVMFuncCall + 70
[bt] (2) 3 libtvm.dylib 0x00000001155e2955 std::__1::__function::__func<tvm::$_5, std::__1::allocator<tvm::$_5>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 181
[bt] (1) 2 libtvm.dylib 0x00000001155e0687 tvm::GenericFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const + 743
[bt] (0) 1 libtvm.dylib 0x0000000115bfe795 std::__1::__function::__func<TVMFuncCreateFromCFunc::$_2, std::__1::allocator<TVMFuncCreateFromCFunc::$_2>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 213
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun
rv = local_pyfunc(*pyargs)
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/op/strategy/cuda.py", line 125, in conv2d_strategy_cuda
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/runtime_ctypes.py", line 218, in compute_version
self.device_type, self.device_id, 4)
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/runtime_ctypes.py", line 180, in _GetDeviceAttr
device_type, device_id, attr_id)
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
raise get_last_ffi_error()
[bt] (6) 7 ??? 0x00007ffee6c6f640 0x0 + 140732770219584
[bt] (5) 6 _ctypes.cpython-37m-darwin.so 0x0000000109a5b36f ffi_call_unix64 + 79
[bt] (4) 5 libtvm.dylib 0x0000000115bfcbd6 TVMFuncCall + 70
[bt] (3) 4 libtvm.dylib 0x0000000115bfecb0 std::__1::__function::__func<$_4, std::__1::allocator<$_4>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 400
[bt] (2) 3 libtvm.dylib 0x0000000115bfde14 tvm::runtime::DeviceAPIManager::GetAPI(int, bool) + 532
[bt] (1) 2 libtvm.dylib 0x0000000115bfe0a5 tvm::runtime::DeviceAPIManager::GetAPI(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, bool) + 421
[bt] (0) 1 libtvm.dylib 0x0000000115197829 dmlc::LogMessageFatal::~LogMessageFatal() + 57
File "/Users/sam/dev/github/tvm/src/runtime/c_runtime_api.cc", line 133
TVMError: Check failed: allow_missing: Device API gpu is not enabled.
What is concerning is that CUDA code is being called even though neither tvm.gpu(0) nor tvm.target.cuda() is mentioned anywhere in the script:
File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/topi-0.7.dev1-py3.7.egg/topi/cuda/conv2d_alter_op.py", line 39, in _alter_conv2d_layout
relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
See:
(venv) kaosnew:build sam$ grep -rn cuda ~/dev/tvm_test/metal_tf_demo.py
(venv) kaosnew:build sam$ grep -rn gpu ~/dev/tvm_test/metal_tf_demo.py
Neither grep returns a match.
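My working theory (an assumption on my part, not something I have confirmed in the TVM sources) is that the 'metal' target also carries the generic 'gpu' key, and the CUDA conv2d strategy and alter-layout handlers are registered against 'gpu' as well as 'cuda', so generic dispatch falls through to them. A quick way to check that theory, assuming the Target API in this 0.7.dev build exposes keys:

import tvm
# If both targets report a shared 'gpu' key, anything registered against
# 'gpu' (which the CUDA conv2d handlers appear to be, per the backtrace)
# can be selected when compiling for Metal.
print(tvm.target.create('metal').keys)
print(tvm.target.create('cuda').keys)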
The contents of the script are as follows:
# tvm, relay
import tvm
from tvm import te
from tvm import relay
# os and numpy
import numpy as np
import os.path
# Tensorflow imports
import tensorflow.compat.v1 as tf
#import tensorflow as tf
tf_compat_v1 = tf
tf.disable_v2_behavior()
# Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing
# Base location for model related files.
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
# Test image
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
# Image label map
map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt'
map_proto_url = os.path.join(repo_base, map_proto)
# Human readable text for labels
label_map = 'imagenet_synset_to_human_label_map.txt'
label_map_url = os.path.join(repo_base, label_map)
# Target settings
target = 'metal'
target_host = 'llvm'
layout = None # "NCHW"
ctx = tvm.context(target, 1)
from tvm.contrib.download import download_testdata
img_path = download_testdata(image_url, img_name, module='data')
model_path = download_testdata(model_url, model_name, module=['tf', 'InceptionV1'])
map_proto_path = download_testdata(map_proto_url, map_proto, module='data')
label_path = download_testdata(label_map_url, label_map, module='data')
with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
    graph_def = tf_compat_v1.GraphDef()
    graph_def.ParseFromString(f.read())
    graph = tf.import_graph_def(graph_def, name='')
    # Call the utility to import the graph definition into default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    # Add shapes to the graph.
    with tf_compat_v1.Session() as sess:
        graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')
from PIL import Image
image = Image.open(img_path).resize((299, 299))
x = np.array(image)
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
mod, params = relay.frontend.from_tensorflow(graph_def,
                                             layout=layout,
                                             shape=shape_dict)
print("Tensorflow protobuf imported to relay frontend.")
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod,
                                     target=target,
                                     target_host=target_host,
                                     params=params)
from tvm.contrib import graph_runtime
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty(((1, 1008)), 'float32'))
predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
                                    uid_lookup_path=label_path)
# Print top 5 predictions from TVM output.
top_k = predictions.argsort()[-5:][::-1]
for node_id in top_k:
    human_string = node_lookup.id_to_string(node_id)
    score = predictions[node_id]
    print('%s (score = %.5f)' % (human_string, score))

def create_graph():
    """Creates a graph from saved GraphDef file and returns a saver."""
    # Creates graph from saved graph_def.pb.
    with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
        graph_def = tf_compat_v1.GraphDef()
        graph_def.ParseFromString(f.read())
        graph = tf.import_graph_def(graph_def, name='')
        # Call the utility to import the graph definition into default graph.
        graph_def = tf_testing.ProcessGraphDefParam(graph_def)

def run_inference_on_image(image):
    """Runs inference on an image.

    Parameters
    ----------
    image: String
        Image file name.

    Returns
    -------
        Nothing
    """
    if not tf_compat_v1.gfile.Exists(image):
        tf.logging.fatal('File does not exist %s', image)
    image_data = tf_compat_v1.gfile.GFile(image, 'rb').read()
    # Creates graph from saved GraphDef.
    create_graph()
    with tf_compat_v1.Session() as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
        predictions = np.squeeze(predictions)
        # Creates node ID --> English string lookup.
        node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
                                            uid_lookup_path=label_path)
        # Print top 5 predictions from tensorflow.
        top_k = predictions.argsort()[-5:][::-1]
        print("===== TENSORFLOW RESULTS =======")
        for node_id in top_k:
            human_string = node_lookup.id_to_string(node_id)
            score = predictions[node_id]
            print('%s (score = %.5f)' % (human_string, score))

run_inference_on_image(img_path)
Is it that contrib.graph_runtime has only ever been tested with CUDA and is making assumptions?
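For what it's worth, a minimal repro along these lines should hit the same failure without the TensorFlow frontend involved at all (an untested sketch; it assumes the fault is in the AlterOpLayout handling of conv2d, which is what the backtrace points at):

import tvm
from tvm import relay

data = relay.var('data', shape=(1, 3, 224, 224), dtype='float32')
weight = relay.var('weight', shape=(16, 3, 3, 3), dtype='float32')
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
# opt_level=3 enables AlterOpLayout, the pass at the top of the backtrace.
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target='metal', target_host='llvm')

If this fails with the same "Device API gpu is not enabled" check, the TF import side can be ruled out.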
For reference, here are the changes from the original tutorial script at https://tvm.apache.org/docs/tutorials/frontend/from_tensorflow.html#sphx-glr-download-tutorials-frontend-from-tensorflow-py
(venv) kaosnew:build sam$ diff from_tensorflow_llvm.py from_tensorflow_metal.py
76c76
< target = 'llvm'
---
> target = 'metal'
79c79
< ctx = tvm.cpu(0)
---
> ctx = tvm.metal(0)