Poor Performance When Tuning Large-Scale Tasks (nn.Linear) with the Auto-Scheduler

Hello,

I'm currently facing performance problems when using the TVM auto-scheduler on tasks extracted from a large-scale PyTorch model. The tuning process does not yield the expected performance improvements and often results in longer inference times than running the model directly in PyTorch.

Here’s a brief overview of the problem setup:

Task Description: The specific operation is a dense matrix multiplication followed by a bias addition:

p0 = PLACEHOLDER [8, 3584]
p1 = PLACEHOLDER [152064, 3584]
T_dense(i, j) += (p0[i, k]*p1[j, k])
p2 = PLACEHOLDER [1, 152064]
T_add(ax0, ax1) = (T_dense[ax0, ax1] + p2[0, ax1])
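
For reference, this corresponds to the following standalone TE compute (a sketch; float16 dtypes are assumed here):

import tvm
from tvm import te

# Shapes taken from the workload printout above.
M, N, K = 8, 152064, 3584
p0 = te.placeholder((M, K), name="p0", dtype="float16")   # activations
p1 = te.placeholder((N, K), name="p1", dtype="float16")   # weight (transposed layout)
p2 = te.placeholder((1, N), name="p2", dtype="float16")   # bias
k = te.reduce_axis((0, K), name="k")
T_dense = te.compute((M, N), lambda i, j: te.sum(p0[i, k] * p1[j, k], axis=k), name="T_dense")
T_add = te.compute((M, N), lambda ax0, ax1: T_dense[ax0, ax1] + p2[0, ax1], name="T_add")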

Hardware and Setup: I'm running these tasks on an NVIDIA RTX 4090 GPU.

Tuning Process: I attempted to tune the model following the tutorial "Auto-scheduling a Neural Network for NVIDIA GPU" (tvm 0.18.dev0 documentation), with the TuningOptions below:

    measure_ctx = auto_scheduler.LocalRunner(number=3, repeat=1, min_repeat_ms=100, timeout=100)

    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=36000,  
        num_measures_per_round=128,
        runner=measure_ctx,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )

I'm puzzled by these poor tuning results and wonder whether I'm missing some crucial optimization or whether the auto-scheduler has inherent limitations for tasks of this scale.

Thanks!!

Hi @qiaoming, for NVIDIA devices with compute capability sm_70 or newer, dense should leverage Tensor Cores for better performance, but the auto-scheduler can only tune for CUDA cores. Consider using MetaSchedule with TensorIR to utilize Tensor Cores on the 4090.

from tvm import meta_schedule as ms

# `mod` is an IRModule containing TIR PrimFuncs, `target` is the CUDA target,
# `workdir` is a directory for tuning logs, and `max_trials` is the global
# trial budget.
database = ms.tune_tir(
    mod=mod,
    target=target,
    max_trials_global=max_trials,
    num_trials_per_iter=16,
    work_dir=workdir,
    space=ms.space_generator.PostOrderApply(
        sch_rules="cuda-tensorcore",
        postprocs="cuda-tensorcore",
        mutator_probs="cuda-tensorcore",
    ),
)
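
After tuning finishes, the best record can be queried back from the database and built; a minimal sketch, assuming the same mod and target that were passed to tune_tir:

import tvm

# Query the database for the best schedule found for `mod`, then build it.
sch = ms.tir_integration.compile_tir(database, mod, target)
if sch is None:
    raise RuntimeError("no valid schedule was found during tuning")
rt_mod = tvm.build(sch.mod, target=target)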

Alternatively, you can check out dlight or fast-dlight (which I have temporarily moved into a new project, microsoft/BitBLAS) to get a high-performance kernel quickly.
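
For the dlight route, here is a minimal sketch, assuming a recent TVM build where tvm.dlight is available and that mod is an IRModule of TIR PrimFuncs (like the matmul example further below):

import tvm
from tvm import dlight as dl

# `mod` is an IRModule of TIR PrimFuncs and `target` a CUDA target for the 4090;
# the rule names follow recent TVM main and may differ between versions.
with target:
    mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),    # schedules matmul-like blocks, using Tensor Cores when applicable
        dl.gpu.Fallback(),  # generic fallback for any remaining blocks
    )(mod)
rt_mod = tvm.build(mod, target=target)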

Thank you very much for your advice! I encountered a problem when using ms.tune_tir: the IRModule I pass in was obtained through relay.frontend.from_pytorch, and I noticed that its functions do not contain any tir.PrimFunc, which prevents me from tuning. How should I go about converting it?

@qiaoming, maybe you can check out the last option (FastDlight / BitBLAS); we also implemented a lightweight and efficient tuner for TIR scripts there.

import tvm
from tvm.script import tir as T

from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
from bitblas.base.arch import CUDA
from bitblas.base.utils import apply_and_build

# Example shapes and dtypes (assumed here; adjust to your workload).
M, N, K = 8, 152064, 3584
in_dtype, out_dtype = "float16", "float16"

@tvm.script.ir_module
class MatmulNT:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [M, K], dtype=in_dtype)
        B = T.match_buffer(b, [N, K], dtype=in_dtype)
        C = T.match_buffer(c, [M, N], dtype=out_dtype)

        for i, j, k in T.grid(M, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = tvm.tir.const(0, out_dtype)
                C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[
                    vj, vk
                ].astype(out_dtype)

ir_module = MatmulNT
func = ir_module["main"]
target = tvm.target.Target("nvidia/nvidia-a100")
arch = CUDA(target)

# Tune with SIMT CUDA cores by default
policy = DefaultPolicy(func=func, arch=arch)
try:
    # get_tensorized_func_and_tags is also provided by BitBLAS
    # (the import path may vary between versions).
    tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)
except Exception:
    tags = None
# Tune with Tensor Cores if possible
if tags:
    policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)

configs = policy.emit_config(topk=20)

cpresults, best = apply_and_build(func, configs, arch, parallel_build=True)

[BitBLAS] Evaluation with config  {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.032 ms
[BitBLAS] Evaluation with config  {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.021 ms
[BitBLAS] Evaluation with config  {'block': [128, 32], 'warp': [64, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.023 ms
[BitBLAS] Evaluation with config  {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.023 ms
[BitBLAS] Evaluation with config  {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.027 ms
[BitBLAS] Evaluation with config  {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.025 ms
[BitBLAS] Evaluation with config  {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.023 ms
[BitBLAS] Evaluation with config  {'block': [128, 64], 'warp': [64, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.025 ms
[BitBLAS] Evaluation with config  {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.037 ms
[BitBLAS] Evaluation with config  {'block': [64, 16], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.037 ms
[BitBLAS] Evaluation with config  {'block': [128, 128], 'warp': [64, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.026 ms
[BitBLAS] Evaluation with config  {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.043 ms
[BitBLAS] Evaluation with config  {'block': [128, 16], 'warp': [32, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.042 ms
[BitBLAS] Evaluation with config  {'block': [32, 256], 'warp': [16, 128], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.025 ms
[BitBLAS] Evaluation with config  {'block': [256, 32], 'warp': [128, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.029 ms
[BitBLAS] Evaluation with config  {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.028 ms
[BitBLAS] Evaluation with config  {'block': [256, 64], 'warp': [128, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.027 ms
[BitBLAS] Evaluation with config  {'block': [128, 256], 'warp': [64, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.044 ms
[BitBLAS] Evaluation with config  {'block': [256, 128], 'warp': [128, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.040 ms
[BitBLAS] Evaluation with config  {'block': [16, 256], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.047 ms
print(best.code)
'''
extern "C" __global__ void __launch_bounds__(128) default_function_kernel(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C) {
  ...
}
'''

@LeiWang1999 Thank you, I will start trying out BitBLAS. By the way, is there no way to directly extract a tir.PrimFunc from Relay?

@qiaoming, I don't think there is a straightforward way to convert a Relay layer into a TIR expression, since the op implementations are mostly registered as TE implementations via TOPI. But you can easily lower a TE expression to a TIR PrimFunc.

import tvm
from tvm import te

# `call` here is a Relay Call node taken from the Relay graph; its op is
# lowered to a TE expression and then turned into a TIR PrimFunc.
te_expr = tvm._ffi.get_global_func("relay.backend.LowerToTE")(call.op)
workload = te.create_prim_func(te_expr)
ir_module = tvm.IRModule({"main": workload})

@LeiWang1999 Thank you for your suggestion. I found that when I create a matrix multiplication directly with te, like this:

func = te.create_prim_func(matmul_fp16(N=N, M=M, K=K, out_dtype=dtype)).with_attr(
    {"global_symbol": "main"}
)
ir_mod = tvm.IRModule({"main": func})

ms.tune_tir runs successfully. But when I pass in a single-layer PyTorch nn.Linear model:

mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
mod = relay.transform.InferType()(mod)
te_expr = tvm._ffi.get_global_func("relay.backend.LowerToTE")(mod["main"])
workload = te.create_prim_func(te_expr)
ir_module = tvm.IRModule({"main": workload})

Using ms.tune_tir results in an error: "TVMError: Do not have a default for tir.Evaluate". Does this mean tune_tir does not support tasks converted from Torch models?

I don't have much experience with the PyTorch frontend; it might be legacy. A more common approach is to export the model to an ONNX file first and then convert it to Relay IR.
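
A rough sketch of that route (the model and file names here are placeholders; the input shape follows the [8, 3584] activation from your task):

import torch
import onnx
from tvm import relay

# `model` stands for your single nn.Linear module (placeholder name).
example_input = torch.randn(8, 3584)
torch.onnx.export(model, example_input, "linear.onnx", input_names=["input"])

# Import the exported ONNX graph into Relay.
onnx_model = onnx.load("linear.onnx")
mod, params = relay.frontend.from_onnx(onnx_model, shape={"input": (8, 3584)})
mod = relay.transform.InferType()(mod)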

@LeiWang1999 I see, thanks a lot!