Here is my code.
import torch
import torchvision
import torch
import time
import numpy as np
import timeit
import argparse
import tvm
from tvm import relay
from tvm.contrib import graph_executor
# compile pytorch model to relay
def torch_to_relay(model, input_name, input_shape):
    """Convert a PyTorch model into Relay through TVM's PyTorch frontend.

    Args:
        model: Model object handed to ``relay.frontend.from_pytorch``
            (the caller in this file passes a ``torch.jit.trace`` result).
        input_name: Name to associate with the model's single input.
        input_shape: Shape of that input.

    Returns:
        The ``(mod, params)`` pair produced by the frontend.
    """
    print("=== Compiling pytorch model to Relay... ===")
    # The frontend takes a list of (name, shape) pairs, one per model input.
    converted = relay.frontend.from_pytorch(model, [(input_name, input_shape)])
    relay_mod, relay_params = converted
    print("=== Done ===\n")
    return relay_mod, relay_params
# get llvm target
def get_llvm_target():
    """Create the TVM target for the LLVM backend and log it to stdout.

    Returns:
        The ``tvm.target.Target`` built from the JSON spec ``{"kind": "llvm"}``.
    """
    print("=== Generating tvm module directly ... ===")
    llvm_spec = '{"kind": "llvm"}'
    llvm_target = tvm.target.Target(target=llvm_spec)
    print("target: {}".format(llvm_target))
    print("=== Done ===\n")
    return llvm_target
# get metal device target
def get_metal_target():
    """Create the TVM target for the Metal backend and log it to stdout.

    The target is configured with ``max_num_threads`` 1024 and
    ``thread_warp_size`` 32.

    Returns:
        The resulting ``tvm.target.Target``.
    """
    metal_spec = '{"kind": "metal", "max_num_threads": 1024, "thread_warp_size": 32}'
    metal_target = tvm.target.Target(metal_spec)
    print("target: {}".format(metal_target))
    return metal_target
# compile the model with relay by llvm
def compile_by_llvm(mod, params, target):
    """Build a Relay module for ``target`` and pick the matching device.

    Args:
        mod: Relay module to compile.
        params: Parameters passed through to ``relay.build``.
        target: Target to compile for; its string form is also used to
            look up the device.

    Returns:
        ``(lib, dev)``: the result of ``relay.build`` and device 0 obtained
        from ``tvm.device(str(target), 0)``.
    """
    pass_ctx = tvm.transform.PassContext(opt_level=3)
    with pass_ctx:
        built_lib = relay.build(mod, target=target, params=params)
    device = tvm.device(str(target), 0)
    return built_lib, device
# compile the model with relay by metal
def compile_by_metal(mod, params, target):
    """Build a Relay module for ``target`` and return it with a Metal device.

    Args:
        mod: Relay module to compile.
        params: Parameters passed through to ``relay.build``.
        target: Target to compile for.

    Returns:
        ``(lib, dev)``: the result of ``relay.build`` and the device
        returned by ``tvm.metal()``.
    """
    pass_ctx = tvm.transform.PassContext(opt_level=3)
    with pass_ctx:
        built_lib = relay.build(mod, target=target, params=params)
    metal_device = tvm.metal()
    return built_lib, metal_device
# collect performance data
def collect_performance_data(module, *, timing_number=10, timing_repeat=10,
                             warmup_runs=1):
    """Time ``module.run()`` and summarise the per-call latency in milliseconds.

    Args:
        module: Object exposing a ``run()`` method callable without arguments
            (the caller in this file passes a ``graph_executor.GraphModule``).
        timing_number: Calls to ``run()`` per timing sample; each sample is
            averaged over this many calls.
        timing_repeat: Number of timing samples to collect.
        warmup_runs: Untimed calls to ``run()`` made before measuring, so that
            one-time first-call costs do not skew the statistics.

    Returns:
        dict with keys ``"mean"``, ``"median"`` and ``"std"``: statistics over
        the ``timing_repeat`` samples, each in milliseconds per call.

    Raises:
        ValueError: if ``timing_number`` or ``timing_repeat`` is less than 1.
    """
    if timing_number < 1 or timing_repeat < 1:
        raise ValueError("timing_number and timing_repeat must be >= 1")

    print("\n=== Collecting performance data ... ===")

    # Warm up outside the timed region: the first call may carry one-time
    # setup cost (e.g. lazy initialisation in the backend) that would
    # otherwise inflate the first sample and the reported std.
    for _ in range(warmup_runs):
        module.run()

    # NOTE(review): this measures wall-clock time around run(). If the backend
    # executes asynchronously, run() may return before the device has
    # finished; GraphModule.benchmark() is an alternative but needs the device
    # handle, which this function does not receive -- TODO confirm.
    timer = timeit.Timer(module.run)
    samples_s = timer.repeat(repeat=timing_repeat, number=timing_number)

    # timeit reports seconds per sample of `timing_number` calls; convert to
    # milliseconds per single call.
    per_call_ms = np.array(samples_s) * 1000 / timing_number

    report = {
        "mean": np.mean(per_call_ms),
        "median": np.median(per_call_ms),
        "std": np.std(per_call_ms),
    }
    print("\n=== Done. ===")
    return report
if __name__ == '__main__':
    # Benchmark a traced ResNet-18 compiled by TVM, first for the CPU (LLVM)
    # and then for Metal, and print the latency statistics of each.

    # Dimensions of the example input tensor passed to torch.rand below.
    n = 1
    c = 3
    h = 224
    w = 224

    # Switch to inference mode before tracing: layers whose behaviour depends
    # on the training flag would otherwise be traced in training mode, which
    # is not what an inference benchmark should capture.
    model = torchvision.models.resnet18().eval()
    example = torch.rand(n, c, h, w)
    traced_script_module = torch.jit.trace(model, example)
    traced_script_module.save("m.pt")

    input_name = "input"
    input_shape = [n, c, h, w]
    mod, params = torch_to_relay(traced_script_module, input_name, input_shape)

    # compile by llvm ...
    target = get_llvm_target()
    lib, dev = compile_by_llvm(mod, params, target=target)
    module = graph_executor.GraphModule(lib["default"](dev))
    report_cpu = collect_performance_data(module=module)
    print('cpu report: {}'.format(report_cpu))

    # compile by metal ...
    target = get_metal_target()
    lib, dev = compile_by_metal(mod, params, target=target)
    module = graph_executor.GraphModule(lib["default"](dev))
    report_metal = collect_performance_data(module=module)
    print('metal report: {}'.format(report_metal))
The results are as follows:
cpu report: {'mean': 35.690075850000014, 'median': 35.350568800000026, 'std': 2.7024873898637622}
metal report: {'mean': 352.1809899699999, 'median': 361.5098249499997, 'std': 28.695996392881764}