Incorrect running performance of MetaSchedule

I built TVM from source on a cluster equipped with different GPUs/CUDA versions and ran MetaSchedule and Ansor to tune the C1D operator. However, the same schedule runs much slower when measured through MetaSchedule than through Ansor. I rebuilt the same version of TVM on my own server, and there the running time of the same schedule under MetaSchedule is similar to Ansor's. Is there a possible reason for this, or is it a bug?

import os
from typing import List

import tvm
import tvm.meta_schedule as ms
from tvm.meta_schedule.arg_info import ArgInfo
from tvm.meta_schedule.builder import BuilderInput, LocalBuilder
from tvm.meta_schedule.runner import EvaluatorConfig, LocalRunner
from tvm.meta_schedule.runner.local_runner import (
    default_alloc_argument,
    default_run_evaluator,
)
from tvm.meta_schedule.search_strategy import MeasureCandidate
from tvm.meta_schedule.testing.space_generation import generate_design_space
from tvm.meta_schedule.testing.te_workload import create_te_workload
from tvm.target import Target
from tvm.tir.schedule import Schedule, Trace

os.environ["PATH"] = os.environ["PATH"] + ":/usr/local/cuda/bin/"


def _target():
    return Target("nvidia/nvidia-a100")

def _design_space(mod):
    return generate_design_space(
        kind="cuda",
        mod=mod,
        target=Target("nvidia/nvidia-a100"),
        types=ms.ScheduleRule,
    )


def build_decision(sch, context, decision):
    """Re-apply the trace of `sch` with some sampling decisions overridden,
    run the postprocessors, and return a MeasureCandidate (or False if a
    postprocessor fails)."""
    # Index the overridden decisions by the string form of their instruction
    # outputs, then merge them with the original decisions of the trace.
    tmp_decision = {}
    for inst in decision:
        tmp_decision[str(inst.outputs)] = decision[inst]
    new_decision = {}
    for inst in sch.trace.decisions:
        if str(inst.outputs) not in tmp_decision:
            new_decision[inst] = sch.trace.decisions[inst]
        else:
            new_decision[inst] = tmp_decision[str(inst.outputs)]
    new_sch = Schedule(context.mod)
    Trace(insts=sch.trace.insts, decisions=new_decision).apply_to_schedule(
        new_sch, remove_postproc=True
    )
    new_sch.enter_postproc()
    for postproc in context.space_generator.postprocs:
        if not postproc.apply(new_sch):
            return False
    return MeasureCandidate(new_sch, ArgInfo.from_entry_func(new_sch.mod, True))


mod = create_te_workload("C1D", 0)
context = ms.TuneContext(
    mod=mod,
    target=Target("nvidia/nvidia-a100"),
    space_generator="post-order-apply",
    search_strategy="replay-func",
    task_name="main",
    rand_state=0,
    num_threads=1,
)
schs = context.generate_design_space()

builder = LocalBuilder(max_workers=32)
evaluator_config = EvaluatorConfig(
    repeat=1,
    min_repeat_ms=100,
    enable_cpu_cache_flush=False,
)
# Note: this runner is not used below; the script calls the runner's helper
# functions (default_alloc_argument / default_run_evaluator) directly.
runner = LocalRunner(timeout_sec=25, evaluator_config=evaluator_config)

# The sampling decisions recorded in the first design-space trace, and the
# values we want to override them with for the C1D schedule.
decision = schs[0].trace.decisions
sch_idx = 0
change_parts = [[[1, 1, 1, 1, 1], [4, 4, 8, 1, 1], [8, 1, 16, 1, 1], [3, 1, 1], [1, 1, 64], 2, 3, 4]]

cands = []
for part in change_parts:
    tmp_decision = {}
    inst_count = 0
    # Pair each sampling instruction in the trace with its overridden decision.
    for inst in schs[0].trace.insts:
        if inst in decision:
            tmp_decision[inst] = part[inst_count]
            inst_count += 1
    cands.append(build_decision(schs[sch_idx], context, tmp_decision))

build_cands = cands
mods = [cand.sch.mod for cand in build_cands]
args_infos = [cand.args_info for cand in build_cands]
# Sanity check: print the decisions actually recorded in each candidate's trace.
print([[cand.sch.trace.decisions[i] for i in cand.sch.trace.insts if i in cand.sch.trace.decisions] for cand in build_cands])
arg_info_json = [info.as_json() for info in args_infos[0]]
device = tvm.runtime.device(dev_type="cuda", dev_id=0)

# Case 1: allocate the random input arguments first, then build and run.
builder_inputs = [BuilderInput(mod, Target("cuda")) for mod in mods]
repeated_args = default_alloc_argument(device, arg_info_json, 1)
builder_results = builder.build(builder_inputs)
rt_mod = tvm.runtime.load_module(builder_results[0].artifact_path)
costs: List[float] = default_run_evaluator(rt_mod, device, evaluator_config, repeated_args)
print(costs)

# Case 2: build first, then allocate the random input arguments and run.
# On the cluster, this ordering gives a much slower measured time.
builder_inputs = [BuilderInput(mod, Target("cuda")) for mod in mods]
builder_results = builder.build(builder_inputs)
repeated_args = default_alloc_argument(device, arg_info_json, 1)
rt_mod = tvm.runtime.load_module(builder_results[0].artifact_path)
costs: List[float] = default_run_evaluator(rt_mod, device, evaluator_config, repeated_args)
print(costs)

For anyone who is interested in this problem: the script above reproduces it. When I run it on the cluster, I get:

[9.24965836803402e-06]
[6.47473028956422e-05]

While on my server, I get:

[6.740663960949926e-06]
[6.752736001940805e-06]

The only difference between the two measurements is the order of builder.build and default_alloc_argument (which generates repeated_args). When the module is built before repeated_args are generated, the cluster returns a much slower result; on my server the order makes no difference.

Temporary solution: replace “tvm.contrib.random.random_fill_for_measure” in python/tvm/meta_schedule/runner/local_runner.py with “tvm.contrib.random.random_fill”.
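An equivalent quick check without editing the installed package, reusing rt_mod, device, arg_info_json, and evaluator_config from the script above (only a sketch; it assumes “tvm.contrib.random.random_fill” is registered in the build, i.e. TVM was compiled with the random contrib enabled):

# Sketch: allocate the inputs with random_fill instead of random_fill_for_measure
# and rerun the same evaluator on the already built module.
from tvm.meta_schedule.runner.utils import alloc_argument_common

f_random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
repeated_args_rf = alloc_argument_common(f_random_fill, device, arg_info_json, 1)
costs_rf = default_run_evaluator(rt_mod, device, evaluator_config, repeated_args_rf)
print(costs_rf)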

Ansor works well because it uses “tvm.contrib.random.random_fill” to generate random input, while MetaSchedule uses “tvm.contrib.random.random_fill_for_measure”.
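A non-invasive variant of the same workaround is to pass a custom allocation function to the runner instead of editing local_runner.py. This is only a sketch under the assumption that, as in TVM 0.16, LocalRunner accepts an f_alloc_argument callable with the same signature as default_alloc_argument:

from typing import List

import tvm
from tvm.meta_schedule.runner import EvaluatorConfig, LocalRunner
from tvm.meta_schedule.runner.utils import (
    T_ARG_INFO_JSON_OBJ_LIST,
    T_ARGUMENT_LIST,
    alloc_argument_common,
)
from tvm.runtime import Device


def alloc_argument_with_random_fill(
    device: Device,
    args_info: T_ARG_INFO_JSON_OBJ_LIST,
    alloc_repeat: int,
) -> List[T_ARGUMENT_LIST]:
    # Same logic as default_alloc_argument, but the buffers are filled with
    # "random_fill" (what Ansor uses) instead of "random_fill_for_measure".
    f_random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
    return alloc_argument_common(f_random_fill, device, args_info, alloc_repeat)


runner = LocalRunner(
    timeout_sec=25,
    evaluator_config=EvaluatorConfig(repeat=1, min_repeat_ms=100, enable_cpu_cache_flush=False),
    f_alloc_argument=alloc_argument_with_random_fill,
)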

Hi, thanks for reporting this. Can you please clarify how you compared the Ansor results against the MetaSchedule results, and whether you used a different repeated-arg function for the two?

I’m using TVM 0.16.

To compare Ansor against MetaSchedule, I manually generate the same schedule for both and find that MetaSchedule measures it as much slower.

For the repeated-arg function, Ansor uses “tvm.contrib.random.random_fill” and MetaSchedule uses “tvm.contrib.random.random_fill_for_measure”, as in the official code.

I think there is something wrong when generating random input on the GPU device. The first observation is that “random_fill_for_measure” can lead to a slower measured result than “random_fill”. The second is that the result with “random_fill_for_measure” depends on the order of operations: if you build the module first, the result is slower than when the random input is generated first, as shown in the script above. Besides, the bug only appears on my cluster machine, which uses CUDA 12.5; on my own machine with CUDA 11.8 it works fine.
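For completeness, here is a small environment-recording sketch (it assumes tvm.contrib.nvcc.get_cuda_version and the standard Device attributes are available) that may help narrow down the CUDA 12.5 vs 11.8 difference between the two machines:

import tvm
from tvm.contrib import nvcc

dev = tvm.cuda(0)
print("nvcc version:", nvcc.get_cuda_version())
print("device:", dev.device_name, "compute capability:", dev.compute_version)
print("random_fill registered:",
      tvm.get_global_func("tvm.contrib.random.random_fill", allow_missing=True) is not None)
print("random_fill_for_measure registered:",
      tvm.get_global_func("tvm.contrib.random.random_fill_for_measure", allow_missing=True) is not None)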