Using Metal, but TVM asks for CUDA? (built with USE_LLVM=ON and USE_METAL=ON)

If people could look into this it would be great.

My hope is to convert a Tensorflow graph to run in Metal on macOS.

Running the following command

python ~/dev/tvm_test/from_tensorflow_metal.py

results in the following traceback:

Traceback (most recent call last):

  File "/Users/sam/dev/tvm_test/metal_tf_demo.py", line 75, in <module>
    params=params)

  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/build_module.py", line 251, in build
    graph_json, mod, params = bld_mod.build(mod, target, target_host, params)

  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/build_module.py", line 120, in build
    self._build(mod, target, target_host)

  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) 9   libtvm.dylib                        0x00000001159c6e5e tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*) + 14
  [bt] (7) 8   libtvm.dylib                        0x00000001159c8cf9 tvm::RelayExpr tvm::relay::MixedModeMutator::Rewrite<tvm::relay::CallNode>(tvm::relay::CallNode const*) + 57
  [bt] (6) 7   libtvm.dylib                        0x00000001159c7c82 tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&) + 2082
  [bt] (5) 6   libtvm.dylib                        0x00000001159c91fe tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::operator()<tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void>&, tvm::runtime::ObjectRef>(tvm::relay::Call const&&&, tvm::Array<tvm::RelayExpr, void>&&&, tvm::runtime::ObjectRef&&) const + 254
  [bt] (4) 5   libtvm.dylib                        0x000000011596c9fd std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 109
  [bt] (3) 4   libtvm.dylib                        0x000000011596ca92 void tvm::runtime::detail::unpack_call_dispatcher<tvm::RelayExpr, 0, 3, tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(tvm::RelayExpr (* const&)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&) + 82
  [bt] (2) 3   libtvm.dylib                        0x0000000115963b14 tvm::RelayExpr tvm::relay::LayoutRewriter<tvm::relay::alter_op_layout::AlterTransformMemorizer>(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&) + 3364
  [bt] (1) 2   libtvm.dylib                        0x00000001159675a4 tvm::relay::alter_op_layout::AlterTransformMemorizer::CallWithNewLayouts(tvm::relay::Call const&, std::__1::vector<tvm::RelayExpr, std::__1::allocator<tvm::RelayExpr> > const&) + 1284
  [bt] (0) 1   libtvm.dylib                        0x0000000115bfe795 std::__1::__function::__func<TVMFuncCreateFromCFunc::$_2, std::__1::allocator<TVMFuncCreateFromCFunc::$_2>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 213
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun
    rv = local_pyfunc(*pyargs)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/op/nn/_nn.py", line 97, in alter_op_layout_conv2d
    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
  File "<decorator-gen-35>", line 2, in conv2d_alter_layout
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/target/generic_func.py", line 267, in dispatch_func
    return dispatch_dict[k](*args, **kwargs)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/topi-0.7.dev1-py3.7.egg/topi/cuda/conv2d_alter_op.py", line 39, in _alter_conv2d_layout
    relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/backend/compile_engine.py", line 183, in select_implementation
    all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/backend/compile_engine.py", line 124, in get_valid_implementations
    strategy = fstrategy(attrs, inputs, out_type, target)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/target/generic_func.py", line 45, in __call__
    return _ffi_api.GenericFuncCallFunc(self, *args)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()
  [bt] (5) 6   ???                                 0x00007ffee6c70d30 0x0 + 140732770225456
  [bt] (4) 5   _ctypes.cpython-37m-darwin.so       0x0000000109a5b36f ffi_call_unix64 + 79
  [bt] (3) 4   libtvm.dylib                        0x0000000115bfcbd6 TVMFuncCall + 70
  [bt] (2) 3   libtvm.dylib                        0x00000001155e2955 std::__1::__function::__func<tvm::$_5, std::__1::allocator<tvm::$_5>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 181
  [bt] (1) 2   libtvm.dylib                        0x00000001155e0687 tvm::GenericFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const + 743
  [bt] (0) 1   libtvm.dylib                        0x0000000115bfe795 std::__1::__function::__func<TVMFuncCreateFromCFunc::$_2, std::__1::allocator<TVMFuncCreateFromCFunc::$_2>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 213
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun
    rv = local_pyfunc(*pyargs)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/op/strategy/cuda.py", line 125, in conv2d_strategy_cuda
    if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/runtime_ctypes.py", line 218, in compute_version
    self.device_type, self.device_id, 4)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/runtime_ctypes.py", line 180, in _GetDeviceAttr
    device_type, device_id, attr_id)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()
  [bt] (6) 7   ???                                 0x00007ffee6c6f640 0x0 + 140732770219584
  [bt] (5) 6   _ctypes.cpython-37m-darwin.so       0x0000000109a5b36f ffi_call_unix64 + 79
  [bt] (4) 5   libtvm.dylib                        0x0000000115bfcbd6 TVMFuncCall + 70
  [bt] (3) 4   libtvm.dylib                        0x0000000115bfecb0 std::__1::__function::__func<$_4, std::__1::allocator<$_4>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 400
  [bt] (2) 3   libtvm.dylib                        0x0000000115bfde14 tvm::runtime::DeviceAPIManager::GetAPI(int, bool) + 532
  [bt] (1) 2   libtvm.dylib                        0x0000000115bfe0a5 tvm::runtime::DeviceAPIManager::GetAPI(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, bool) + 421
  [bt] (0) 1   libtvm.dylib                        0x0000000115197829 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  File "/Users/sam/dev/github/tvm/src/runtime/c_runtime_api.cc", line 133
TVMError: Check failed: allow_missing: Device API gpu is not enabled.

What is concerning is that CUDA code is being called even though neither tvm.gpu(0) nor tvm.target.cuda() is mentioned anywhere in my script:

  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/topi-0.7.dev1-py3.7.egg/topi/cuda/conv2d_alter_op.py", line 39, in _alter_conv2d_layout
    relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)

see

(venv) kaosnew:build sam$ grep -rn cuda ~/dev/tvm_test/metal_tf_demo.py
(venv) kaosnew:build sam$ grep -rn gpu ~/dev/tvm_test/metal_tf_demo.py

Neither search returns any matches.

The contents of the script are below:

# tvm, relay
import tvm
from tvm import te
from tvm import relay

# os and numpy
import numpy as np
import os.path

# Tensorflow imports
import tensorflow.compat.v1 as tf
#import tensorflow as tf
tf_compat_v1 = tf
tf.disable_v2_behavior()

# Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing

# Base location for model related files.
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'

# Test image
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)

# Serialized GraphDef (.pb) of the model; parsed further down in this script.
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)

# Image label map
map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt'
map_proto_url = os.path.join(repo_base, map_proto)

# Human readable text for labels
label_map = 'imagenet_synset_to_human_label_map.txt'
label_map_url = os.path.join(repo_base, label_map)

# Target settings
# `target` is passed to relay.build as the compilation target and
# `target_host` as its host target.
target = 'metal'
target_host = 'llvm'
# No explicit layout is requested from the TensorFlow frontend.
layout = None # "NCHW"
# NOTE(review): this requests device index 1, whereas the diff quoted at the
# end of this post shows `ctx = tvm.metal(0)` -- confirm which index is intended.
ctx = tvm.context(target, 1)
from tvm.contrib.download import download_testdata

# Resolve each remote file to a local path via TVM's download helper
# (presumably cached under the given `module` sub-directory -- confirm).
img_path = download_testdata(image_url, img_name, module='data')
model_path = download_testdata(model_url, model_name, module=['tf', 'InceptionV1'])
map_proto_path = download_testdata(map_proto_url, map_proto, module='data')
label_path = download_testdata(label_map_url, label_map, module='data')

# Read the serialized model from disk and parse it into a GraphDef message.
with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
    graph_def = tf_compat_v1.GraphDef()
    graph_def.ParseFromString(f.read())
    # Import the parsed definition into TensorFlow with an empty name prefix.
    # NOTE(review): `graph` is not read before it is rebound by the
    # relay.build call later in this script.
    graph = tf.import_graph_def(graph_def, name='')
    # Pass graph_def through the TVM TensorFlow testing helper
    # (presumably validates/normalises the proto -- confirm).
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    # Add shapes to the graph.
    # graph_def is replaced by the helper's result for the 'softmax' node
    # (presumably the session graph with output shapes attached -- confirm).
    with tf_compat_v1.Session() as sess:
        graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')

from PIL import Image
# Load the test image and resize it to 299x299.
image = Image.open(img_path).resize((299, 299))

# Pixel data as a NumPy array; its shape is used for the declared input shape
# and its contents are fed to the runtime further down.
x = np.array(image)

# Declared shape for the graph input node.
shape_dict = {'DecodeJpeg/contents': x.shape}
# NOTE(review): dtype_dict is not referenced again in the visible script.
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
# Convert the TensorFlow GraphDef into a Relay module and its parameters.
mod, params = relay.frontend.from_tensorflow(graph_def,
                                             layout=layout,
                                             shape=shape_dict)

print("Tensorflow protobuf imported to relay frontend.")
# Compile the Relay module for `target` / `target_host` at opt_level=3.
# NOTE(review): per the traceback in this report, the TVMError is raised from
# inside this relay.build call.
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod,
                                     target=target,
                                     target_host=target_host,
                                     params=params)



from tvm.contrib import graph_runtime
dtype = 'uint8'
# Create a graph runtime module from the build outputs on the chosen context.
m = graph_runtime.create(graph, lib, ctx)
# set inputs
# The image array is cast to uint8 and wrapped in a TVM NDArray.
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
# Output 0 is fetched into a freshly allocated float32 array of shape (1, 1008).
tvm_output = m.get_output(0, tvm.nd.empty(((1, 1008)), 'float32'))

# Convert to NumPy and drop the size-1 dimensions.
predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions)

# Creates node ID --> English string lookup.
node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
                                    uid_lookup_path=label_path)

# Print top 5 predictions from TVM output.
# argsort is ascending, so the last five indices reversed give the highest
# scores first.
top_k = predictions.argsort()[-5:][::-1]
for node_id in top_k:
    human_string = node_lookup.id_to_string(node_id)
    score = predictions[node_id]
    print('%s (score = %.5f)' % (human_string, score))


def create_graph():
    """Import the saved GraphDef file into TensorFlow.

    Reads the serialized GraphDef from the module-level ``model_path``,
    parses it and imports it with an empty name prefix.

    Returns
    -------
        Nothing
    """
    # Creates graph from saved graph_def.pb.
    with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
        graph_def = tf_compat_v1.GraphDef()
        graph_def.ParseFromString(f.read())
        # Import the graph definition; the return value is not needed here.
        tf.import_graph_def(graph_def, name='')
        # Kept from the original flow: pass the proto through the TVM testing
        # helper. Its result is not used by this function.
        tf_testing.ProcessGraphDefParam(graph_def)

def run_inference_on_image(image):
    """Runs TensorFlow inference on an image and prints the top 5 predictions.

    The image bytes are fed to the 'DecodeJpeg/contents:0' tensor and the
    'softmax:0' tensor is evaluated; the five highest-scoring labels are
    printed together with their scores.

    Parameters
    ----------
    image: String
        Image file name.

    Returns
    -------
        Nothing
    """
    if not tf_compat_v1.gfile.Exists(image):
        # NOTE(review): this only logs the problem; execution continues to the
        # read below -- confirm whether an explicit abort is wanted.
        tf.logging.fatal('File does not exist %s', image)
    # Read the raw bytes inside a context manager so the file handle is
    # closed even if the read raises.
    with tf_compat_v1.gfile.GFile(image, 'rb') as image_file:
        image_data = image_file.read()

    # Creates graph from saved GraphDef.
    create_graph()

    with tf_compat_v1.Session() as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})

        # Drop the size-1 dimensions so predictions can be indexed by node id.
        predictions = np.squeeze(predictions)

        # Creates node ID --> English string lookup.
        node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
                                            uid_lookup_path=label_path)

        # Print top 5 predictions from tensorflow.
        # argsort is ascending, so the last five indices reversed give the
        # highest scores first.
        top_k = predictions.argsort()[-5:][::-1]
        print("===== TENSORFLOW RESULTS =======")
        for node_id in top_k:
            human_string = node_lookup.id_to_string(node_id)
            score = predictions[node_id]
            print('%s (score = %.5f)' % (human_string, score))

# Run the same test image through TensorFlow so its top-5 output can be
# compared with the TVM output printed earlier in the script.
run_inference_on_image(img_path)

Is it that contrib.graph_runtime has only ever been tested with CUDA and is making assumptions?

Here are the changes from the original found here: https://tvm.apache.org/docs/tutorials/frontend/from_tensorflow.html#sphx-glr-download-tutorials-frontend-from-tensorflow-py

(venv) kaosnew:build sam$ diff from_tensorflow_llvm.py from_tensorflow_metal.py
76c76
< target = 'llvm'
---
> target = 'metal'
79c79
< ctx = tvm.cpu(0)
---
> ctx = tvm.metal(0)