Tune graph torch quantized model not find input node

when test auto tuning for quantized model,it‘s not find input node?

import os
import numpy as np
import torch
import tvm
from tvm import relay, autotvm
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
import tvm.contrib.graph_executor as runtime

from PIL import Image
import numpy as np
import time
import torch
import tvm.relay.testing
from torchvision.models.quantization import mobilenet as qmobilenet
from tvm import auto_scheduler
from tvm.contrib import graph_executor
# Replace "llvm" with the correct target of your CPU.
# For example, for AWS EC2 c5 instance with Intel Xeon
# Platinum 8000 series, the target should be "llvm -mcpu=skylake-avx512".
# For AWS EC2 c4 instance with Intel Xeon E5-2666 v3, it should be
# "llvm -mcpu=core-avx2".
target = "llvm -mcpu=skylake-avx512"

batch_size = 1
dtype = "float32"
model_name = "mobilenet"
log_file = "%s.log" % model_name
graph_opt_sch_file = "%s_graph_opt.log" % model_name

# Set the input name of the graph
# For ONNX models, it is typically "0".
input_name = "input"
use_sparse = False

# Set number of threads used for tuning based on the number of
# physical CPU cores on your machine.
num_threads = 1
os.environ["TVM_NUM_THREADS"] = str(num_threads)

tuning_option = {
    "log_filename": log_file,
    "tuner": "random",
    "early_stopping": None,
    "measure_option": autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(
            number=1, repeat=10, min_repeat_ms=0, enable_cpu_cache_flush=True
        ),
    ),
}


def get_transform():
    import torchvision.transforms as transforms

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )


def get_real_image(im_height, im_width):
    # img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
    # img_path = download_testdata(img_url, "cat.png", module="data")
    img_path = "/home/user/project/mmdetection/projects/pruner/tvm/cat.png"
    return Image.open(img_path).resize((im_height, im_width))


def get_imagenet_input():
    im = get_real_image(224, 224)
    preprocess = get_transform()
    pt_tensor = preprocess(im)
    return np.expand_dims(pt_tensor.numpy(), 0)


def get_synset():
    # synset_url = "".join(
    #     [
    #         "https://gist.githubusercontent.com/zhreshold/",
    #         "4d0b62f3d01426887599d4f7ede23ee5/raw/",
    #         "596b27d23537e5a1b5751d2b0481ef172f58b539/",
    #         "imagenet1000_clsid_to_human.txt",
    #     ]
    # )
    # # return synset_url
    # synset_name = "imagenet1000_clsid_to_human.txt"
    # synset_path = download_testdata(synset_url, synset_name, module="data")
    synset_path = "/home/user/tvm/imagenet1000_clsid_to_human.txt"
    with open(synset_path) as f:
        return eval(f.read())


def quantize_model(model, inp):
    model.fuse_model()
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.quantization.prepare(model, inplace=True)
    # Dummy calibration
    model(inp)
    torch.quantization.convert(model, inplace=True)


def get_network(name, batch_size):
    """Get the symbol definition and random weight of a network"""
    input_shapes = [("input", (batch_size, 3, 224, 224))]
    output_shape = (batch_size, 1000)

    inp = get_imagenet_input()
    qmodel = qmobilenet.mobilenet_v2(pretrained=True).eval()
    pt_inp = torch.from_numpy(inp)
    quantize_model(qmodel, pt_inp)
    script_module = torch.jit.trace(qmodel, pt_inp).eval()

    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

    return mod, params, (1, 3, 224, 224), output_shape

# You can skip the implementation of this function for this tutorial.
def tune_kernels(
    tasks, measure_option, tuner="gridsearch", early_stopping=None, log_filename="tuning.log"
):

    for i, task in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

        # create tuner
        if tuner == "xgb" or tuner == "xgb-rank":
            tuner_obj = XGBTuner(task, loss_type="rank")
        elif tuner == "ga":
            tuner_obj = GATuner(task, pop_size=50)
        elif tuner == "random":
            tuner_obj = RandomTuner(task)
        elif tuner == "gridsearch":
            tuner_obj = GridSearchTuner(task)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        # do tuning
        n_trial = len(task.config_space)
        tuner_obj.tune(
            n_trial=n_trial,
            early_stopping=early_stopping,
            measure_option=measure_option,
            callbacks=[
                autotvm.callback.progress_bar(n_trial, prefix=prefix),
                autotvm.callback.log_to_file(log_filename),
            ],
        )


# Use graph tuner to achieve graph level optimal schedules
# Set use_DP=False if it takes too long to finish.
def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
    target_op = [
        relay.op.get("nn.conv2d"),
    ]
    print(target_op)
    Tuner = DPTuner if use_DP else PBQPTuner
    # import pdb; pdb.set_trace()
    executor = Tuner(graph, {input_name: dshape}, records, target_op, target)
    executor.benchmark_layout_transform(min_exec_num=2000)
    executor.run()
    executor.write_opt_sch2record_file(opt_sch_file)


def evaluate_performance(lib, data_shape):
    # upload parameters to device
    dev = tvm.cpu()
    data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype))
    module = runtime.GraphModule(lib["default"](dev))
    module.set_input(input_name, data_tvm)

    # evaluate
    print("Evaluate inference time cost...")
    print(module.benchmark(dev, number=100, repeat=3))


def tune_and_evaluate(tuning_opt):
    # extract workloads from relay program
    print("Extract tasks...")
    mod, params, data_shape, out_shape = get_network(model_name, batch_size)

    tasks = autotvm.task.extract_from_program(
        mod["main"], target=target, params=params,
        ops=(relay.op.get("nn.conv2d"),))
    # run tuning tasks
    tune_kernels(tasks, **tuning_opt)
    tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file)

    # compile kernels in default mode
    print("Evaluation of the network compiled in 'default' mode without auto tune:")
    with tvm.transform.PassContext(opt_level=3):
        print("Compile...")
        lib = relay.build(mod, target=target, params=params)
        evaluate_performance(lib, data_shape)

    # compile kernels in kernel tuned only mode
    print("\nEvaluation of the network been tuned on kernel level:")
    with autotvm.apply_history_best(log_file):
        print("Compile...")
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build(mod, target=target, params=params)
        evaluate_performance(lib, data_shape)

    # compile kernels with graph-level best records
    print("\nEvaluation of the network been tuned on graph level:")
    with autotvm.apply_graph_best(graph_opt_sch_file):
        print("Compile...")
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build_module.build(mod, target=target, params=params)
        evaluate_performance(lib, data_shape)

# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run it by yourself.


import os
os.environ['OPENBLAS_NUM_THREADS'] = '4'
tune_and_evaluate(tuning_option)

tune_kernels is correctly, but speed slower,

tune graph is error.

Traceback (most recent call last): File “/home/shanshaojie/project/mmdetection/projects/pruner/tvm/auto_tuning.py”, line 331, in tune_and_evaluate(tuning_option) File “/home/shanshaojie/project/mmdetection/projects/pruner/tvm/auto_tuning.py”, line 243, in tune_and_evaluate tune_graph(mod[“main”], data_shape, log_file, graph_opt_sch_file) File “/home/shanshaojie/project/mmdetection/projects/pruner/tvm/auto_tuning.py”, line 215, in tune_graph executor = Tuner(graph, {input_name: dshape}, records, target_op, target) File “/home/shanshaojie/anaconda3/lib/python3.6/site-packages/tvm-0.10.0-py3.6-linux-x86_64.egg/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py”, line 44, in init super(DPTuner, self).init(*args, **kwargs) File “/home/shanshaojie/anaconda3/lib/python3.6/site-packages/tvm-0.10.0-py3.6-linux-x86_64.egg/tvm/autotvm/graph_tuner/base_graph_tuner.py”, line 178, in init “operator is one of %s” % self._target_ops RuntimeError: Could not find any input nodes with whose operator is one of [Op(nn.conv2d)]