I am attempting to tune the pretrained AlexNet model from torchvision
using meta schedule. My minimal script (see below) is based on one of the meta schedule unit tests (test_meta_schedule_tune_relay.py).
Most of the time, executing this script leads to the intended result: the network is successfully tuned via meta schedule. However, a small fraction of the time the script fails with the following error:
Traceback (most recent call last):
File "min_error.py", line 23, in <module>
rt_mod: tvm.module = tune_relay(
File "/workspaces/tvm/python/tvm/meta_schedule/tune.py", line 713, in tune_relay
task_scheduler.tune()
File "/workspaces/tvm/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 61, in tune
_ffi_api.TaskSchedulerTune(self) # type: ignore # pylint: disable=no-member
File "tvm/_ffi/_cython/./packed_func.pxi", line 323, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 163, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
8: TVMFuncCall
7: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
6: tvm::meta_schedule::TaskSchedulerNode::Tune()
5: tvm::meta_schedule::PostOrderApplyNode::GenerateDesignSpace(tvm::IRModule const&)
4: tvm::meta_schedule::RandomComputeLocationNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&)
3: tvm::tir::TracedScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)
2: tvm::tir::ConcreteScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)
1: tvm::tir::ScheduleStateNode::DebugVerify() const
0: tvm::tir::VerifyCachedFlags(tvm::tir::ScheduleState const&)
File "/workspaces/tvm/src/tir/schedule/analysis/verify.cc", line 236
TVMError: Schedule verification failed. The IR is:
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(placeholder: T.Buffer[(1, 96, 1, 1, 4), "float32"], placeholder_1: T.Buffer[(96, 48, 3, 3, 4, 4), "float32"], placeholder_2: T.Buffer[(1, 48, 13, 13, 4), "float32"], T_relu: T.Buffer[(1, 96, 13, 13, 4), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.parallel":32, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64})
            data_pad = T.alloc_buffer([1, 48, 15, 15, 4], dtype="float32")
            conv2d_NCHWc = T.alloc_buffer([1, 96, 13, 13, 4], dtype="float32")
            for i0_0, i1_0, i2_0, i3_0, i4_0, i0_1, i1_1, i2_1, i3_1, i4_1, i5_0, i6_0, i7_0 in T.grid(1, 3, 1, 1, 2, 1, 8, 1, 1, 2, 32, 1, 3):
                for ax0 in T.serial(1):
                    for ax1, ax2, ax3, ax4 in T.grid(T.min(47 - i5_0 * 6 // 4, (i5_0 * 6 % 4 + 5) // 4) + 1, 15, 13, 4):
                        with T.block("data_pad"):
                            i0 = T.axis.spatial(1, ax0)
                            i1 = T.axis.spatial(48, i5_0 * 6 // 4 + ax1)
                            i2 = T.axis.spatial(15, ax2)
                            i3 = T.axis.spatial(15, i7_0 + ax3)
                            i4 = T.axis.spatial(4, ax4)
                            T.reads(placeholder_2[i0, i1, T.max(i2 - 1, 0) : T.max(i2 - 1, 0) + (T.min(i2, 13) - T.max(i2 - 1, 0)), T.max(i3 - 1, 0) : T.max(i3 - 1, 0) + (T.min(i3, 13) - T.max(i3 - 1, 0)), i4])
                            T.writes(data_pad[i0, i1, i2, i3, i4])
                            data_pad[i0, i1, i2, i3, i4] = T.if_then_else(1 <= i2 and i2 < 14 and 1 <= i3 and i3 < 14, placeholder_2[i0, i1, i2 - 1, i3 - 1, i4], T.float32(0), dtype="float32")
                for i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 1, 1, 13, 1, 6, 3, 1, 1, 4, 13, 1, 1):
                    with T.block("conv2d_NCHWc"):
                        n = T.axis.spatial(1, 0)
                        oc_chunk = T.axis.spatial(96, i1_0 * 32 + i1_1 * 4 + i1_3)
                        oh, ow = T.axis.remap("SS", [i2_3, i3_2])
                        oc_block = T.axis.spatial(4, i4_0 * 2 + i4_1)
                        ic = T.axis.reduce(192, i5_0 * 6 + i5_1)
                        kh, kw = T.axis.remap("RR", [i6_1, i7_0])
                        T.reads(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block], data_pad[n, ic // 4, oh + kh, ow + kw, ic % 4], placeholder_1[oc_chunk, ic // 4, kh, kw, ic % 4, oc_block])
                        T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block])
                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS", "workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 48, 13, 13, 4], "float32"], ["TENSOR", [96, 48, 3, 3, 4, 4], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "NCHW4c", "NCHW4c", "float32"]})
                        with T.init():
                            conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0)
                        conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 4, oh + kh, ow + kw, ic % 4] * placeholder_1[oc_chunk, ic // 4, kh, kw, ic % 4, oc_block]
            for i0, i1, i2, i3, i4 in T.grid(1, 96, 13, 13, 4):
                with T.block("T_relu"):
                    ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                    T.reads(conv2d_NCHWc[ax0, ax1, ax2, ax3, ax4], placeholder[ax0, ax1, 0, 0, ax4])
                    T.writes(T_relu[ax0, ax1, ax2, ax3, ax4])
                    T_relu[ax0, ax1, ax2, ax3, ax4] = T.max(conv2d_NCHWc[ax0, ax1, ax2, ax3, ax4] + placeholder[ax0, ax1, 0, 0, ax4], T.float32(0))
The errors are:
- Wrong region_cover: (conv2d_NCHWc, expected=0, actual=1)
- Wrong stage_pipeline: (root, expected=0, actual=1)
I suspect this is caused by random sampling within meta schedule, but I would appreciate any help in working around it (see the retry sketch after the script below).
Thanks!
Script causing the error:
import tempfile
import torch
import torchvision
import tvm
from tvm import relay
from tvm.target.target import Target
from tvm.meta_schedule import ReplayTraceConfig
from tvm.meta_schedule.tune import tune_relay
indata = torch.rand((1, 3, 224, 224))
model = torchvision.models.alexnet(pretrained=True)
model.eval()
scripted_model = torch.jit.trace(model, indata)
scripted_model.eval()
shape_list = [("input0", indata.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
with tempfile.TemporaryDirectory() as work_dir:
    target = Target("llvm --num-cores=2")
    rt_mod: tvm.module = tune_relay(
        mod=mod,
        params=params,
        target=target,
        config=ReplayTraceConfig(
            num_trials_per_iter=32,
            num_trials_total=32,
        ),
        work_dir=work_dir,
    )
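In the meantime, the only way I have found to get past the failure is to catch the error and simply rerun the tuning call, since it only fails a small fraction of the time. Below is a minimal sketch of that workaround; the tune_with_retry helper and the max_retries parameter are my own additions, not part of the meta schedule API:

import tempfile

import tvm
from tvm.meta_schedule import ReplayTraceConfig
from tvm.meta_schedule.tune import tune_relay


def tune_with_retry(mod, params, target, max_retries=3):
    """Retry tune_relay a few times, since the verification failure is intermittent.

    Hypothetical workaround written for this post; not part of the meta schedule API.
    """
    last_err = None
    for attempt in range(max_retries):
        try:
            with tempfile.TemporaryDirectory() as work_dir:
                # Same arguments as in the reproduction script above.
                return tune_relay(
                    mod=mod,
                    params=params,
                    target=target,
                    config=ReplayTraceConfig(
                        num_trials_per_iter=32,
                        num_trials_total=32,
                    ),
                    work_dir=work_dir,
                )
        except tvm.TVMError as err:
            # The VerifyCachedFlags failure surfaces as a TVMError (see traceback above).
            last_err = err
            print(f"Tuning attempt {attempt + 1} failed, retrying...")
    raise last_err

This obviously just papers over the nondeterminism rather than fixing it, so I am still interested in the root cause.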
Configuration details:
- TVM commit: 22c488e3a829ad700de6547be6096fb2d1f02e81
- PyTorch version: 1.7.0
- Host CPU (from llc --version | grep "Host CPU"): icelake-client