First tutorial fails with Metal

Using the following script (the relay_quick_start.py tutorial, modified to build for the llvm target and run on a Metal context):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
.. _tutorial-relay-quick-start:

Quick Start Tutorial for Compiling Deep Learning Models
=======================================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_, `Truman Tian <https://github.com/SiNZeRo>`_

This example shows how to build a neural network with the Relay python frontend and
generate a runtime library for Nvidia GPU with TVM.
Notice that you need to build TVM with cuda and llvm enabled.
"""

######################################################################
# Overview for Supported Hardware Backend of TVM
# ----------------------------------------------
# The image below shows the hardware backends currently supported by TVM:
#
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tvm_support_list.png
#      :align: center
#
# In this tutorial, we'll choose cuda and llvm as target backends.
# To begin with, let's import Relay and TVM.

import numpy as np

from tvm import relay
from tvm.relay import testing
import tvm
from tvm import te
from tvm.contrib import graph_runtime

######################################################################
# Define Neural Network in Relay
# ------------------------------
# First, let's define a neural network with the Relay python frontend.
# For simplicity, we'll use the pre-defined resnet-18 network in Relay.
# Parameters are initialized with Xavier initializer.
# Relay also supports other model formats such as MXNet, CoreML, ONNX and
# Tensorflow.
#
# In this tutorial, we assume we will do inference on our device and
# the batch size is set to 1. Input images are RGB color
# images of size 224 * 224. We can call :any:`tvm.relay.TupleWrapper.astext()`
# to show the network structure.

batch_size = 1
num_class = 1000
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_class)

mod, params = relay.testing.resnet.get_workload(
    num_layers=18, batch_size=batch_size, image_shape=image_shape)

# set show_meta_data=True if you want to show meta data
print(mod.astext(show_meta_data=False))

######################################################################
# Compilation
# -----------
# The next step is to compile the model using the Relay/TVM pipeline.
# Users can specify the optimization level of the compilation.
# Currently this value can be 0 to 3. The optimization passes include
# operator fusion, pre-computation, layout transformation and so on.
#
# :py:func:`relay.build` returns three components: the execution graph in
# json format, the TVM module library of compiled functions specifically
# for this graph on the target hardware, and the parameter blobs of
# the model. During the compilation, Relay does the graph-level
# optimization while TVM does the tensor-level optimization, resulting
# in an optimized runtime module for model serving.
#
# We'll first compile for Nvidia GPU. Behind the scenes, :py:func:`relay.build`
# first does a number of graph-level optimizations, e.g. pruning, fusing, etc.,
# then maps the operators (i.e. the nodes of the optimized graphs) to
# TVM implementations to generate a `tvm.module`.
# To generate the module library, TVM first lowers the high-level IR
# into the intrinsic IR of the specified target backend, which is CUDA
# in this example. Then the machine code is generated as the module library.

opt_level = 3
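# NOTE: the tutorial builds with tvm.target.cuda() here; since there is no
# CUDA on this Mac, I left the target as llvm.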
target = tvm.target.create("llvm")
with relay.build_config(opt_level=opt_level):
    graph, lib, params = relay.build(mod, target, params=params)

#####################################################################
# Run the generated library
# -------------------------
# Now we can create a graph runtime and run the module on the GPU.

# create random input
ctx = tvm.metal()
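# NOTE: the tutorial has tvm.gpu() here; I changed it to tvm.metal() to use
# the Mac's GPU.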
data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
# create module
module = graph_runtime.create(graph, lib, ctx)
# set input and parameters
module.set_input("data", data)
module.set_input(**params)
# run
module.run()
# get output
out = module.get_output(0, tvm.nd.empty(out_shape)).asnumpy()

# Print first 10 elements of output
print(out.flatten()[0:10])

######################################################################
# Save and Load Compiled Module
# -----------------------------
# We can also save the graph, lib and parameters into files and load them
# back in the deployment environment.

####################################################

# save the graph, lib and params into separate files
from tvm.contrib import util

temp = util.tempdir()
path_lib = temp.relpath("deploy_lib.tar")
lib.export_library(path_lib)
with open(temp.relpath("deploy_graph.json"), "w") as fo:
    fo.write(graph)
with open(temp.relpath("deploy_param.params"), "wb") as fo:
    fo.write(relay.save_param_dict(params))
print(temp.listdir())

####################################################

# load the module back.
loaded_json = open(temp.relpath("deploy_graph.json")).read()
loaded_lib = tvm.runtime.load_module(path_lib)
loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read())
input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))

module = graph_runtime.create(loaded_json, loaded_lib, ctx)
module.load_params(loaded_params)
module.run(data=input_data)
out_deploy = module.get_output(0).asnumpy()

# Print first 10 elements of output
print(out_deploy.flatten()[0:10])

# check whether the output from deployed module is consistent with original one
tvm.testing.assert_allclose(out_deploy, out, atol=1e-3)

I get the following result:

  (venv) kaosnew:tvm sam$ python ~/Downloads/relay_quick_start.py 

v0.0.4
def @main(%data: Tensor[(1, 3, 224, 224), float32], %bn_data_gamma: Tensor[(3), float32], %bn_data_beta: Tensor[(3), float32], %bn_data_moving_mean: Tensor[(3), float32], %bn_data_moving_var: Tensor[(3), float32], %conv0_weight: Tensor[(64, 3, 7, 7), float32], %bn0_gamma: Tensor[(64), float32], %bn0_beta: Tensor[(64), float32], %bn0_moving_mean: Tensor[(64), float32], %bn0_moving_var: Tensor[(64), float32], %stage1_unit1_bn1_gamma: Tensor[(64), float32], %stage1_unit1_bn1_beta: Tensor[(64), float32], %stage1_unit1_bn1_moving_mean: Tensor[(64), float32], %stage1_unit1_bn1_moving_var: Tensor[(64), float32], %stage1_unit1_conv1_weight: Tensor[(64, 64, 3, 3), float32], %stage1_unit1_bn2_gamma: Tensor[(64), float32], %stage1_unit1_bn2_beta: Tensor[(64), float32], %stage1_unit1_bn2_moving_mean: Tensor[(64), float32], %stage1_unit1_bn2_moving_var: Tensor[(64), float32], %stage1_unit1_conv2_weight: Tensor[(64, 64, 3, 3), float32], %stage1_unit1_sc_weight: Tensor[(64, 64, 1, 1), float32], %stage1_unit2_bn1_gamma: Tensor[(64), float32], %stage1_unit2_bn1_beta: Tensor[(64), float32], %stage1_unit2_bn1_moving_mean: Tensor[(64), float32], %stage1_unit2_bn1_moving_var: Tensor[(64), float32], %stage1_unit2_conv1_weight: Tensor[(64, 64, 3, 3), float32], %stage1_unit2_bn2_gamma: Tensor[(64), float32], %stage1_unit2_bn2_beta: Tensor[(64), float32], %stage1_unit2_bn2_moving_mean: Tensor[(64), float32], %stage1_unit2_bn2_moving_var: Tensor[(64), float32], %stage1_unit2_conv2_weight: Tensor[(64, 64, 3, 3), float32], %stage2_unit1_bn1_gamma: Tensor[(64), float32], %stage2_unit1_bn1_beta: Tensor[(64), float32], %stage2_unit1_bn1_moving_mean: Tensor[(64), float32], %stage2_unit1_bn1_moving_var: Tensor[(64), float32], %stage2_unit1_conv1_weight: Tensor[(128, 64, 3, 3), float32], %stage2_unit1_bn2_gamma: Tensor[(128), float32], %stage2_unit1_bn2_beta: Tensor[(128), float32], %stage2_unit1_bn2_moving_mean: Tensor[(128), float32], %stage2_unit1_bn2_moving_var: Tensor[(128), float32], %stage2_unit1_conv2_weight: Tensor[(128, 128, 3, 3), float32], %stage2_unit1_sc_weight: Tensor[(128, 64, 1, 1), float32], %stage2_unit2_bn1_gamma: Tensor[(128), float32], %stage2_unit2_bn1_beta: Tensor[(128), float32], %stage2_unit2_bn1_moving_mean: Tensor[(128), float32], %stage2_unit2_bn1_moving_var: Tensor[(128), float32], %stage2_unit2_conv1_weight: Tensor[(128, 128, 3, 3), float32], %stage2_unit2_bn2_gamma: Tensor[(128), float32], %stage2_unit2_bn2_beta: Tensor[(128), float32], %stage2_unit2_bn2_moving_mean: Tensor[(128), float32], %stage2_unit2_bn2_moving_var: Tensor[(128), float32], %stage2_unit2_conv2_weight: Tensor[(128, 128, 3, 3), float32], %stage3_unit1_bn1_gamma: Tensor[(128), float32], %stage3_unit1_bn1_beta: Tensor[(128), float32], %stage3_unit1_bn1_moving_mean: Tensor[(128), float32], %stage3_unit1_bn1_moving_var: Tensor[(128), float32], %stage3_unit1_conv1_weight: Tensor[(256, 128, 3, 3), float32], %stage3_unit1_bn2_gamma: Tensor[(256), float32], %stage3_unit1_bn2_beta: Tensor[(256), float32], %stage3_unit1_bn2_moving_mean: Tensor[(256), float32], %stage3_unit1_bn2_moving_var: Tensor[(256), float32], %stage3_unit1_conv2_weight: Tensor[(256, 256, 3, 3), float32], %stage3_unit1_sc_weight: Tensor[(256, 128, 1, 1), float32], %stage3_unit2_bn1_gamma: Tensor[(256), float32], %stage3_unit2_bn1_beta: Tensor[(256), float32], %stage3_unit2_bn1_moving_mean: Tensor[(256), float32], %stage3_unit2_bn1_moving_var: Tensor[(256), float32], %stage3_unit2_conv1_weight: Tensor[(256, 256, 3, 3), float32], %stage3_unit2_bn2_gamma: Tensor[(256), float32], %stage3_unit2_bn2_beta: Tensor[(256), float32], %stage3_unit2_bn2_moving_mean: Tensor[(256), float32], %stage3_unit2_bn2_moving_var: Tensor[(256), float32], %stage3_unit2_conv2_weight: Tensor[(256, 256, 3, 3), float32], %stage4_unit1_bn1_gamma: Tensor[(256), float32], %stage4_unit1_bn1_beta: Tensor[(256), float32], %stage4_unit1_bn1_moving_mean: Tensor[(256), float32], %stage4_unit1_bn1_moving_var: Tensor[(256), float32], %stage4_unit1_conv1_weight: Tensor[(512, 256, 3, 3), float32], %stage4_unit1_bn2_gamma: Tensor[(512), float32], %stage4_unit1_bn2_beta: Tensor[(512), float32], %stage4_unit1_bn2_moving_mean: Tensor[(512), float32], %stage4_unit1_bn2_moving_var: Tensor[(512), float32], %stage4_unit1_conv2_weight: Tensor[(512, 512, 3, 3), float32], %stage4_unit1_sc_weight: Tensor[(512, 256, 1, 1), float32], %stage4_unit2_bn1_gamma: Tensor[(512), float32], %stage4_unit2_bn1_beta: Tensor[(512), float32], %stage4_unit2_bn1_moving_mean: Tensor[(512), float32], %stage4_unit2_bn1_moving_var: Tensor[(512), float32], %stage4_unit2_conv1_weight: Tensor[(512, 512, 3, 3), float32], %stage4_unit2_bn2_gamma: Tensor[(512), float32], %stage4_unit2_bn2_beta: Tensor[(512), float32], %stage4_unit2_bn2_moving_mean: Tensor[(512), float32], %stage4_unit2_bn2_moving_var: Tensor[(512), float32], %stage4_unit2_conv2_weight: Tensor[(512, 512, 3, 3), float32], %bn1_gamma: Tensor[(512), float32], %bn1_beta: Tensor[(512), float32], %bn1_moving_mean: Tensor[(512), float32], %bn1_moving_var: Tensor[(512), float32], %fc1_weight: Tensor[(1000, 512), float32], %fc1_bias: Tensor[(1000), float32]) -> Tensor[(1, 1000), float32] {
  %0 = nn.batch_norm(%data, %bn_data_gamma, %bn_data_beta, %bn_data_moving_mean, %bn_data_moving_var, epsilon=2e-05f, scale=False) /* ty=(Tensor[(1, 3, 224, 224), float32], Tensor[(3), float32], Tensor[(3), float32]) */;
  %1 = %0.0;
  %2 = nn.conv2d(%1, %conv0_weight, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %3 = nn.batch_norm(%2, %bn0_gamma, %bn0_beta, %bn0_moving_mean, %bn0_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 64, 112, 112), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %4 = %3.0;
  %5 = nn.relu(%4) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %6 = nn.max_pool2d(%5, pool_size=[3, 3], strides=[2, 2], padding=[1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %7 = nn.batch_norm(%6, %stage1_unit1_bn1_gamma, %stage1_unit1_bn1_beta, %stage1_unit1_bn1_moving_mean, %stage1_unit1_bn1_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %8 = %7.0;
  %9 = nn.relu(%8) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %10 = nn.conv2d(%9, %stage1_unit1_conv1_weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %11 = nn.batch_norm(%10, %stage1_unit1_bn2_gamma, %stage1_unit1_bn2_beta, %stage1_unit1_bn2_moving_mean, %stage1_unit1_bn2_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %12 = %11.0;
  %13 = nn.relu(%12) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %14 = nn.conv2d(%13, %stage1_unit1_conv2_weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %15 = nn.conv2d(%9, %stage1_unit1_sc_weight, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %16 = add(%14, %15) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %17 = nn.batch_norm(%16, %stage1_unit2_bn1_gamma, %stage1_unit2_bn1_beta, %stage1_unit2_bn1_moving_mean, %stage1_unit2_bn1_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %18 = %17.0;
  %19 = nn.relu(%18) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %20 = nn.conv2d(%19, %stage1_unit2_conv1_weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %21 = nn.batch_norm(%20, %stage1_unit2_bn2_gamma, %stage1_unit2_bn2_beta, %stage1_unit2_bn2_moving_mean, %stage1_unit2_bn2_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %22 = %21.0;
  %23 = nn.relu(%22) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %24 = nn.conv2d(%23, %stage1_unit2_conv2_weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %25 = add(%24, %16) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %26 = nn.batch_norm(%25, %stage2_unit1_bn1_gamma, %stage2_unit1_bn1_beta, %stage2_unit1_bn1_moving_mean, %stage2_unit1_bn1_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %27 = %26.0;
  %28 = nn.relu(%27) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %29 = nn.conv2d(%28, %stage2_unit1_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], channels=128, kernel_size=[3, 3]) /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %30 = nn.batch_norm(%29, %stage2_unit1_bn2_gamma, %stage2_unit1_bn2_beta, %stage2_unit1_bn2_moving_mean, %stage2_unit1_bn2_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 128, 28, 28), float32], Tensor[(128), float32], Tensor[(128), float32]) */;
  %31 = %30.0;
  %32 = nn.relu(%31) /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %33 = nn.conv2d(%32, %stage2_unit1_conv2_weight, padding=[1, 1, 1, 1], channels=128, kernel_size=[3, 3]) /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %34 = nn.conv2d(%28, %stage2_unit1_sc_weight, strides=[2, 2], padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1]) /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %35 = add(%33, %34) /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %36 = nn.batch_norm(%35, %stage2_unit2_bn1_gamma, %stage2_unit2_bn1_beta, %stage2_unit2_bn1_moving_mean, %stage2_unit2_bn1_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 128, 28, 28), float32], Tensor[(128), float32], Tensor[(128), float32]) */;
  %37 = %36.0;
  %38 = nn.relu(%37) /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %39 = nn.conv2d(%38, %stage2_unit2_conv1_weight, padding=[1, 1, 1, 1], channels=128, kernel_size=[3, 3]) /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %40 = nn.batch_norm(%39, %stage2_unit2_bn2_gamma, %stage2_unit2_bn2_beta, %stage2_unit2_bn2_moving_mean, %stage2_unit2_bn2_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 128, 28, 28), float32], Tensor[(128), float32], Tensor[(128), float32]) */;
  %41 = %40.0;
  %42 = nn.relu(%41) /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %43 = nn.conv2d(%42, %stage2_unit2_conv2_weight, padding=[1, 1, 1, 1], channels=128, kernel_size=[3, 3]) /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %44 = add(%43, %35) /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %45 = nn.batch_norm(%44, %stage3_unit1_bn1_gamma, %stage3_unit1_bn1_beta, %stage3_unit1_bn1_moving_mean, %stage3_unit1_bn1_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 128, 28, 28), float32], Tensor[(128), float32], Tensor[(128), float32]) */;
  %46 = %45.0;
  %47 = nn.relu(%46) /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %48 = nn.conv2d(%47, %stage3_unit1_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], channels=256, kernel_size=[3, 3]) /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %49 = nn.batch_norm(%48, %stage3_unit1_bn2_gamma, %stage3_unit1_bn2_beta, %stage3_unit1_bn2_moving_mean, %stage3_unit1_bn2_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 256, 14, 14), float32], Tensor[(256), float32], Tensor[(256), float32]) */;
  %50 = %49.0;
  %51 = nn.relu(%50) /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %52 = nn.conv2d(%51, %stage3_unit1_conv2_weight, padding=[1, 1, 1, 1], channels=256, kernel_size=[3, 3]) /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %53 = nn.conv2d(%47, %stage3_unit1_sc_weight, strides=[2, 2], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %54 = add(%52, %53) /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %55 = nn.batch_norm(%54, %stage3_unit2_bn1_gamma, %stage3_unit2_bn1_beta, %stage3_unit2_bn1_moving_mean, %stage3_unit2_bn1_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 256, 14, 14), float32], Tensor[(256), float32], Tensor[(256), float32]) */;
  %56 = %55.0;
  %57 = nn.relu(%56) /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %58 = nn.conv2d(%57, %stage3_unit2_conv1_weight, padding=[1, 1, 1, 1], channels=256, kernel_size=[3, 3]) /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %59 = nn.batch_norm(%58, %stage3_unit2_bn2_gamma, %stage3_unit2_bn2_beta, %stage3_unit2_bn2_moving_mean, %stage3_unit2_bn2_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 256, 14, 14), float32], Tensor[(256), float32], Tensor[(256), float32]) */;
  %60 = %59.0;
  %61 = nn.relu(%60) /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %62 = nn.conv2d(%61, %stage3_unit2_conv2_weight, padding=[1, 1, 1, 1], channels=256, kernel_size=[3, 3]) /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %63 = add(%62, %54) /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %64 = nn.batch_norm(%63, %stage4_unit1_bn1_gamma, %stage4_unit1_bn1_beta, %stage4_unit1_bn1_moving_mean, %stage4_unit1_bn1_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 256, 14, 14), float32], Tensor[(256), float32], Tensor[(256), float32]) */;
  %65 = %64.0;
  %66 = nn.relu(%65) /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %67 = nn.conv2d(%66, %stage4_unit1_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3]) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %68 = nn.batch_norm(%67, %stage4_unit1_bn2_gamma, %stage4_unit1_bn2_beta, %stage4_unit1_bn2_moving_mean, %stage4_unit1_bn2_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 512, 7, 7), float32], Tensor[(512), float32], Tensor[(512), float32]) */;
  %69 = %68.0;
  %70 = nn.relu(%69) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %71 = nn.conv2d(%70, %stage4_unit1_conv2_weight, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3]) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %72 = nn.conv2d(%66, %stage4_unit1_sc_weight, strides=[2, 2], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %73 = add(%71, %72) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %74 = nn.batch_norm(%73, %stage4_unit2_bn1_gamma, %stage4_unit2_bn1_beta, %stage4_unit2_bn1_moving_mean, %stage4_unit2_bn1_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 512, 7, 7), float32], Tensor[(512), float32], Tensor[(512), float32]) */;
  %75 = %74.0;
  %76 = nn.relu(%75) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %77 = nn.conv2d(%76, %stage4_unit2_conv1_weight, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3]) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %78 = nn.batch_norm(%77, %stage4_unit2_bn2_gamma, %stage4_unit2_bn2_beta, %stage4_unit2_bn2_moving_mean, %stage4_unit2_bn2_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 512, 7, 7), float32], Tensor[(512), float32], Tensor[(512), float32]) */;
  %79 = %78.0;
  %80 = nn.relu(%79) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %81 = nn.conv2d(%80, %stage4_unit2_conv2_weight, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3]) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %82 = add(%81, %73) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %83 = nn.batch_norm(%82, %bn1_gamma, %bn1_beta, %bn1_moving_mean, %bn1_moving_var, epsilon=2e-05f) /* ty=(Tensor[(1, 512, 7, 7), float32], Tensor[(512), float32], Tensor[(512), float32]) */;
  %84 = %83.0;
  %85 = nn.relu(%84) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %86 = nn.global_avg_pool2d(%85) /* ty=Tensor[(1, 512, 1, 1), float32] */;
  %87 = nn.batch_flatten(%86) /* ty=Tensor[(1, 512), float32] */;
  %88 = nn.dense(%87, %fc1_weight, units=1000) /* ty=Tensor[(1, 1000), float32] */;
  %89 = nn.bias_add(%88, %fc1_bias, axis=-1) /* ty=Tensor[(1, 1000), float32] */;
  nn.softmax(%89) /* ty=Tensor[(1, 1000), float32] */
}

Cannot find config for target=llvm, workload=('dense_nopack.x86', ('TENSOR', (1, 512), 'float32'), ('TENSOR', (1000, 512), 'float32'), None, 'float32'). A fallback configuration is used, which may bring great performance regression.
[17:18:21] /Users/sam/dev/github/tvm/src/runtime/metal/metal_device_api.mm:136: Intializing Metal device 0, name=Intel(R) HD Graphics Unknown
[17:18:22] /Users/sam/dev/github/tvm/src/runtime/metal/metal_device_api.mm:136: Intializing Metal device 1, name=AMD Radeon Pro 560
Traceback (most recent call last):

File "/Users/sam/Downloads/relay_quick_start.py", line 116, in <module>
  module.run()

File "/Users/sam/dev/github/tvm/build/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/contrib/graph_runtime.py", line 176, in run
  self._run()

File "/Users/sam/dev/github/tvm/build/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
  raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (6) 7 ??? 0x00007ffee578b440 0x0 + 140732748313664
[bt] (5) 6 _ctypes.cpython-37m-darwin.so 0x00000001153fe36f ffi_call_unix64 + 79
[bt] (4) 5 libtvm.dylib 0x00000001168727e6 TVMFuncCall + 70
[bt] (3) 4 libtvm.dylib 0x00000001168cd38f std::__1::__function::__func<tvm::runtime::GraphRuntime::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_9, std::__1::allocator<tvm::runtime::GraphRuntime::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_9>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 79
[bt] (2) 3 libtvm.dylib 0x00000001168ca871 std::__1::__function::__func<tvm::runtime::GraphRuntime::CreateTVMOp(tvm::runtime::TVMOpParam const&, std::__1::vector<DLTensor, std::__1::allocator<DLTensor> > const&, unsigned long)::$_2, std::__1::allocator<tvm::runtime::GraphRuntime::CreateTVMOp(tvm::runtime::TVMOpParam const&, std::__1::vector<DLTensor, std::__1::allocator<DLTensor> > const&, unsigned long)::$_2>, void ()>::operator()() + 81
[bt] (1) 2 libtvm.dylib 0x000000011687c5e9 std::__1::__function::__func<tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0, std::__1::allocator<tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 313
[bt] (0) 1 libtvm.dylib 0x0000000115e0e4f9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
File "/Users/sam/dev/github/tvm/src/runtime/library_module.cc", line 89
TVMError: Check failed: ret == 0 (-1 vs. 0) : Assert fail: (dev_type == 1), device_type need to be 1

There is mention of the Metal devices, which is further than I have travelled before.
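If I am reading the check right, device_type 1 is the plain CPU (kDLCPU in dlpack), so my guess (unconfirmed) is that the library was compiled for the llvm CPU target while the graph runtime is handing it Metal buffers, i.e. these two lines of my script disagree:

    target = tvm.target.create("llvm")  # module is compiled for the CPU (device_type 1)
    ctx = tvm.metal()                   # but is run on a Metal context

(The "Cannot find config" line above looks like the usual untuned-schedule fallback warning, so I assume it is unrelated.)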

How do you specify a Metal target properly? The docs seem lacking here.
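For reference, here is what I would have guessed, by analogy with the cuda path. This is only a sketch: it assumes TVM was built with USE_METAL enabled, and I have not verified it end to end.

    # build for Metal instead of llvm, keeping llvm as the host target
    target = tvm.target.create("metal")
    with relay.build_config(opt_level=opt_level):
        graph, lib, params = relay.build(mod, target, target_host="llvm", params=params)

    # run on a Metal context instead of the CPU
    ctx = tvm.metal(0)
    module = graph_runtime.create(graph, lib, ctx)
    module.set_input("data", data)
    module.set_input(**params)
    module.run()

If that is right, it would also explain the assert above, since the ctx passed to graph_runtime.create has to match the target the library was built for.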