Built model has more parameters than the original

Hey,

I’m trying to compile a PyTorch model with TVM. These are the steps I use:

import torch
import tvm
from tvm import relay
from tvm.contrib import graph_executor

# model, example_input, input_shape and cinput are defined elsewhere
input_name = [("input", (input_shape, "float16"))]
modelScripted = torch.jit.trace(model, example_input)
mod, params = relay.frontend.from_pytorch(
    modelScripted, input_name
)
dev = tvm.cuda(0)
target = tvm.target.cuda()
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)
m = graph_executor.GraphModule(lib["default"](dev))
m.set_input("input", cinput)
m.run()
tvm_output = m.get_output(0).asnumpy()

The result has shape (1, 2, 96, 96, 96), which is correct, but the values don’t match what I get from calling .forward() on modelScripted. Specifically, all values of tvm_output[0][0] are identical and all values of tvm_output[0][1] are identical: each equals the corresponding weight of the last layer.

I believe the problem is that extra parameters are created during the relay.build step. I suspect this because len(params) is 44 while len(list(m.get_input_info()[0])) is 51; I would expect the latter to be 45, since there are 44 parameters and 1 input. When I look at the values with m.get_input() I can see that some (but not all) parameters are tensors consisting exclusively of zeros. I believe these cause the output to collapse to the weights of the last layer.
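For reference, this is roughly how I inspected the bound values (a sketch; the indices and names depend on the model):

shape_info, _ = m.get_input_info()  # (shape dict, dtype dict)
print(len(shape_info), "inputs reported")
for i in range(m.get_num_inputs()):
    arr = m.get_input(i).numpy()
    print(i, arr.shape, "all zeros" if (arr == 0).all() else "has values")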

Any idea how I could debug this? I’m new to TVM and don’t know where to look.

Can you share your model with a repro script?

Hey, thank you for the response. I’m not able to share the model, but I’ve put together a small example that shows what goes wrong:

import torch
from tvm import relay
import tvm
from torch import nn


class Model(nn.Module):
    def __init__(
        self,
    ):
        super(Model, self).__init__()

        self.c1 = nn.Conv3d(1, 1, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.GroupNorm(num_groups=1, num_channels=1)
        self.final_activation = nn.Sigmoid()

        self.apply(init_params)

    def forward(self, x):
        x = self.c1(x)
        x = self.relu1(x)
        x = self.pool(x)
        x = self.final_activation(x)
        return x


def init_params(m):
    if hasattr(m, "weight"):
        m.weight.data = torch.randn(m.weight.size()) * 0.1
    if hasattr(m, "bias"):
        m.bias.data = torch.randn(m.bias.size()) * 0.1


model = Model().half().cuda(0)

input_shape = (1, 1, 96, 96, 96)
input_name = [("input", (input_shape, "float16"))]
input_data = torch.randn(input_shape).half()

modelScripted = torch.jit.trace(model, input_data.cuda())

pytorch_output = modelScripted.forward(input_data.cuda()).detach().cpu().numpy()

mod, params = relay.frontend.from_pytorch(
    modelScripted, input_name, default_dtype="float16"
)
dev = tvm.cuda(0)
target = tvm.target.cuda()

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

# executing
from tvm.contrib import graph_executor

dtype = "float16"
m = graph_executor.GraphModule(lib["default"](dev))


cinput = tvm.nd.array(input_data)
cinput = cinput.copyto(dev)
m.set_input("input", cinput)

m.run()
tvm_output = m.get_output(0).asnumpy()
print(pytorch_output)
print(tvm_output)
print("DONE")

I didn’t manage to reproduce exactly the same behavior. The problem here is that the TVM output consists only of NaN values, while PyTorch returns finite values.

EDIT: here, len(m.get_input_info()[0]) is 6 and len(params) is 4; I would expect 5 and 4.

I ran this code and the two outputs mostly agree. Here are the max and mean absolute differences:

max abs diff: 0.0004883, mean abs diff: 5.08e-05
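Roughly computed as follows (casting to float32 first so the diff itself isn’t limited by float16 precision):

import numpy as np

diff = np.abs(pytorch_output.astype("float32") - tvm_output.astype("float32"))
print(diff.max(), diff.mean())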

I don’t know why m.get_input_info() returns

{"p0": [1, 1, 3, 3, 3], "p1": [1, 1, 1, 1], "p2": [], "p4": [1, 1, 1, 1], "p3": [1, 1, 1, 1], "input": [1, 1, 96, 96, 96]}

p0–p4 are compile-time constants that are embedded directly in the compiled module, so they are not counted among the “inputs”. There is only one real input in this model, named “input”.
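You can see which constants the build kept by asking the factory module returned by relay.build (a small sketch):

print(sorted(lib.get_params().keys()))     # the folded constants, e.g. p0 ... p4
print(list(m.get_input_info()[0].keys()))  # these show up alongside the real "input"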


That’s strange; for me the outputs don’t match. I’ve also tried running with float32 instead of float16, and then it works. My script:

import torch
from tvm import relay
import tvm
from torch import nn
from tvm.contrib.debugger.debug_executor import GraphModuleDebug

HALF: bool = True


class Model(nn.Module):
    def __init__(
        self,
    ):
        super(Model, self).__init__()

        self.c1 = nn.Conv3d(1, 1, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.GroupNorm(num_groups=1, num_channels=1)
        self.final_activation = nn.Sigmoid()

        self.apply(init_params)

    def forward(self, x):
        x = self.c1(x)
        x = self.relu1(x)
        x = self.pool(x)
        x = self.final_activation(x)
        return x


def init_params(m):
    if hasattr(m, "weight"):
        m.weight.data = torch.randn(m.weight.size()) * 0.1
    if hasattr(m, "bias"):
        m.bias.data = torch.randn(m.bias.size()) * 0.1


model = Model().cuda(0)
if HALF:
    model = model.half()

input_shape = (1, 1, 96, 96, 96)
input_name = [("input", (input_shape, "float16" if HALF else "float32"))]
input_data = torch.randn(input_shape)
if HALF:
    input_data = input_data.half()

modelScripted = torch.jit.trace(model, input_data.cuda())

pytorch_output = modelScripted.forward(input_data.cuda()).detach().cpu().numpy()

mod, params = relay.frontend.from_pytorch(
    modelScripted, input_name, default_dtype=("float16" if HALF else "float32")
)
dev = tvm.cuda(0)
target = tvm.target.cuda()

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

# executing
from tvm.contrib import graph_executor

dtype = "float16" if HALF else "float32"
m = graph_executor.GraphModule(lib["default"](dev))


cinput = tvm.nd.array(input_data)
cinput = cinput.copyto(dev)
m.set_input("input", cinput)

m.run()
tvm_output = m.get_output(0).asnumpy()
print(pytorch_output.flatten()[:10])
print(tvm_output.flatten()[:10])
print(f"num of inputs: {m.get_num_inputs()}")
print(f"num of params: {len(params)}")
print("DONE")

If I set HALF to False the outputs match, but when I set it to True, tvm_output contains only NaN.
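In case it helps to localize this: the debug executor (imported above but unused) can dump every intermediate tensor, so one can check which op first produces NaN. A sketch, assuming the v0.10 debug_executor API:

from tvm.contrib.debugger import debug_executor
import numpy as np

md = debug_executor.create(lib.get_graph_json(), lib.get_lib(), dev, dump_root="/tmp/tvmdbg")
md.set_input("input", cinput)
md.run()
for name, tensor in md.debug_datum.get_output_tensors().items():
    if np.isnan(tensor.numpy()).any():
        print("first NaN at:", name)
        break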

I’m using v0.10.0

Another issue is that TVM is significantly slower than PyTorch. I believe this is because TVM is still using the CPU even though I set the target to cuda: I get similar speeds with the CPU target as with CUDA. For the example above, PyTorch takes 0.005 seconds while TVM takes 0.015, and the gap grows for larger models. I’ve also tried tuning the model, but I don’t see any performance gains, and I get some warnings. The code I use:
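A note on the timing itself: the measurement in the script below includes the host-to-device copy, and CUDA launches are asynchronous, so wall-clock timing around run() can be misleading. GraphModule has a built-in benchmark helper that handles warm-up and device synchronization; a sketch:

print(m.benchmark(dev, repeat=10, number=3))  # times "run" only

For a fair PyTorch number, synchronize around the call:

torch.cuda.synchronize()
start = time.time()
_ = modelScripted(input_data.cuda())
torch.cuda.synchronize()
print(f"pytorch took: {time.time() - start}")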

import torch
from tvm import relay
import tvm
from torch import nn
from tvm.contrib.debugger.debug_executor import GraphModuleDebug
import time

HALF: bool = False
TUNING: bool = True


class Model(nn.Module):
    def __init__(
        self,
    ):
        super(Model, self).__init__()

        self.c1 = nn.Conv3d(1, 1, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.GroupNorm(num_groups=1, num_channels=1)
        self.final_activation = nn.Sigmoid()

        self.apply(init_params)

    def forward(self, x):
        x = self.c1(x)
        x = self.relu1(x)
        x = self.pool(x)
        x = self.final_activation(x)
        return x


def init_params(m):
    if hasattr(m, "weight"):
        m.weight.data = torch.randn(m.weight.size()) * 0.1
    if hasattr(m, "bias"):
        m.bias.data = torch.randn(m.bias.size()) * 0.1


model = Model().cuda(0)
if HALF:
    model = model.half()

input_shape = (1, 1, 96, 96, 96)
input_name = [("input", (input_shape, "float16" if HALF else "float32"))]
input_data = torch.randn(input_shape)
if HALF:
    input_data = input_data.half()

modelScripted = torch.jit.trace(model, input_data.cuda())

mod, params = relay.frontend.from_pytorch(
    modelScripted, input_name, default_dtype=("float16" if HALF else "float32")
)

target = tvm.target.Target("cuda")
dev = tvm.device("cuda", 0)


if not TUNING:
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
else:
    from tvm.autotvm.tuner import XGBTuner
    from tvm import autotvm

    number = 10
    repeat = 1
    min_repeat_ms = 100
    timeout = 100  # in seconds

    # create a TVM runner
    runner = autotvm.LocalRunner(
        number=number,
        repeat=repeat,
        timeout=timeout,
        min_repeat_ms=min_repeat_ms,
        enable_cpu_cache_flush=True,
    )

    tuning_option = {
        "tuner": "xgb",
        "trials": 20,
        "early_stopping": 100,
        "measure_option": autotvm.measure_option(
            builder=autotvm.LocalBuilder(build_func="default"), runner=runner
        ),
        "tuning_records": "tuning_records.json",  # TODO : change, log file
    }

    # extract the tuning tasks from the relay program
    tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params)

    # Tune the extracted tasks sequentially.
    for i, task in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
        tuner_obj = XGBTuner(task, loss_type="rank", num_threads=1)
        tuner_obj.tune(
            n_trial=min(tuning_option["trials"], len(task.config_space)),
            early_stopping=tuning_option["early_stopping"],
            measure_option=tuning_option["measure_option"],
            callbacks=[
                autotvm.callback.progress_bar(tuning_option["trials"], prefix=prefix),
                autotvm.callback.log_to_file(tuning_option["tuning_records"]),
            ],
        )

    with autotvm.apply_history_best(tuning_option["tuning_records"]):
        with tvm.transform.PassContext(opt_level=3, config={}):
            lib = relay.build(mod, target=target, params=params)

# executing
from tvm.contrib import graph_executor

dtype = "float16" if HALF else "float32"
m = graph_executor.GraphModule(lib["default"](dev))

start = time.time()
cinput = tvm.nd.array(input_data)
cinput = cinput.copyto(dev)
m.set_input("input", cinput)

m.run()
tvm_output = m.get_output(0).asnumpy()
print(f"tvm took: {time.time() - start}")

start = time.time()
pytorch_output = modelScripted.forward(input_data.cuda()).detach().cpu().numpy()
print(f"pytorch took: {time.time() - start}")


print(pytorch_output.flatten()[:10])
print(tvm_output.flatten()[:10])
print("DONE")

I get these logs while tuning:

[Task  1/ 2]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (20/20) | 16.15 s
WARNING:root:Could not find any valid schedule for task Task(func_name=conv3d_ncdhw.cuda, args=(('TENSOR', (1, 1, 96, 96, 96), 'float32'), ('TENSOR', (1, 1, 3, 3, 3), 'float32'), (1, 1, 1), (1, 1, 1, 1, 1, 1), (1, 1, 1), 1, 'float32'), kwargs={}, workload=('conv3d_ncdhw.cuda', ('TENSOR', (1, 1, 96, 96, 96), 'float32'), ('TENSOR', (1, 1, 3, 3, 3), 'float32'), (1, 1, 1), (1, 1, 1, 1, 1, 1), (1, 1, 1), 1, 'float32')). A file containing the errors has been written to /tmp/tvm_tuning_errors_8w8xy05f.log.
[Task  2/ 2]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (20/20) | 36.38 s
WARNING:root:Could not find any valid schedule for task Task(func_name=conv3d_ncdhw_winograd.cuda, args=(('TENSOR', (1, 1, 96, 96, 96), 'float32'), ('TENSOR', (1, 1, 3, 3, 3), 'float32'), (1, 1, 1), (1, 1, 1, 1, 1, 1), (1, 1, 1), 1, 'float32'), kwargs={}, workload=('conv3d_ncdhw_winograd.cuda', ('TENSOR', (1, 1, 96, 96, 96), 'float32'), ('TENSOR', (1, 1, 3, 3, 3), 'float32'), (1, 1, 1), (1, 1, 1, 1, 1, 1), (1, 1, 1), 1, 'float32')). A file containing the errors has been written to /tmp/tvm_tuning_errors_ao2ydcb0.log.

and here’s the content of /tmp/tvm_tuning_errors_8w8xy05f.log (full log on Pastebin): Traceback (most recent call last): File "/root/.local/lib/python3.8/site-pack…

I’m using v0.10.0

Can you try the latest main?

Not sure about your tuning test, but AutoTVM is very old. I suggest trying the meta schedule or the auto-scheduler.


Thank you for the help. Using the latest main fixes the float16 issue. I’ve also tried tuning with the auto-scheduler and that works great: I get speeds similar to PyTorch. Here’s my latest version of the example code, in case anyone stumbles upon this question:

from typing import Literal
import torch
from tvm import relay
import tvm
from torch import nn
from tvm.contrib.debugger.debug_executor import GraphModuleDebug
import time

HALF: bool = True
TUNETYPE: Literal[None, "autoTVM", "autoScheduling"] = "autoScheduling"


class Model(nn.Module):
    def __init__(
        self,
    ):
        super(Model, self).__init__()

        self.c1 = nn.Conv3d(1, 1, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.GroupNorm(num_groups=1, num_channels=1)
        self.final_activation = nn.Sigmoid()

        self.apply(init_params)

    def forward(self, x):
        x = self.c1(x)
        x = self.relu1(x)
        x = self.pool(x)
        x = self.final_activation(x)
        return x


def init_params(m):
    if hasattr(m, "weight"):
        m.weight.data = torch.randn(m.weight.size()) * 0.1
    if hasattr(m, "bias"):
        m.bias.data = torch.randn(m.bias.size()) * 0.1


model = Model().cuda(0)
if HALF:
    model = model.half()

input_shape = (1, 1, 96, 96, 96)
input_name = [("input", (input_shape, "float16" if HALF else "float32"))]
input_data = torch.randn(input_shape)
if HALF:
    input_data = input_data.half()

modelScripted = torch.jit.trace(model, input_data.cuda())

mod, params = relay.frontend.from_pytorch(
    modelScripted, input_name, default_dtype=("float16" if HALF else "float32")
)

target = tvm.target.Target("cuda")
dev = tvm.device("cuda", 0)

if TUNETYPE == "autoScheduling":
    from tvm import auto_scheduler

    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
    for idx, task in enumerate(tasks):
        print(
            "========== Task %d  (workload key: %s) =========="
            % (idx, task.workload_key)
        )
        print(task.compute_dag)

    print("Begin tuning...")
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(
        repeat=1, min_repeat_ms=300, timeout=10
    )

    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=200,  # change this to 20000 to achieve the best performance
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile("tuning_records.json")],
    )

    tuner.tune(tune_option)

    print("Compile...")
    with auto_scheduler.ApplyHistoryBest("tuning_records.json"):
        with tvm.transform.PassContext(
            opt_level=3, config={"relay.backend.use_auto_scheduler": True}
        ):
            lib = relay.build(mod, target=target, params=params)

elif TUNETYPE == "autoTVM":
    from tvm.autotvm.tuner import XGBTuner
    from tvm import autotvm

    number = 10
    repeat = 1
    min_repeat_ms = 100
    timeout = 100  # in seconds

    # create a TVM runner
    runner = autotvm.LocalRunner(
        number=number,
        repeat=repeat,
        timeout=timeout,
        min_repeat_ms=min_repeat_ms,
        enable_cpu_cache_flush=True,
    )

    tuning_option = {
        "tuner": "xgb",
        "trials": 20,
        "early_stopping": 100,
        "measure_option": autotvm.measure_option(
            builder=autotvm.LocalBuilder(build_func="default"), runner=runner
        ),
        "tuning_records": "tuning_records.json",  # TODO : change, log file
    }

    # extract the tuning tasks from the relay program
    tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params)

    # Tune the extracted tasks sequentially.
    for i, task in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
        tuner_obj = XGBTuner(task, loss_type="rank", num_threads=1)
        tuner_obj.tune(
            n_trial=min(tuning_option["trials"], len(task.config_space)),
            early_stopping=tuning_option["early_stopping"],
            measure_option=tuning_option["measure_option"],
            callbacks=[
                autotvm.callback.progress_bar(tuning_option["trials"], prefix=prefix),
                autotvm.callback.log_to_file(tuning_option["tuning_records"]),
            ],
        )

    with autotvm.apply_history_best(tuning_option["tuning_records"]):
        with tvm.transform.PassContext(opt_level=3, config={}):
            lib = relay.build(mod, target=target, params=params)
else:
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
# executing
from tvm.contrib import graph_executor

dtype = "float16" if HALF else "float32"
m = graph_executor.GraphModule(lib["default"](dev))

start = time.time()
cinput = tvm.nd.array(input_data)
cinput = cinput.copyto(dev)
m.set_input("input", cinput)

m.run()
tvm_output = m.get_output(0).asnumpy()
print(f"tvm took: {time.time() - start}")

start = time.time()
pytorch_output = modelScripted.forward(input_data.cuda()).detach().cpu().numpy()
print(f"pytorch took: {time.time() - start}")


print(pytorch_output.flatten()[:10])
print(tvm_output.flatten()[:10])
print(f"num of inputs: {m.get_num_inputs()}")
print(f"num of params: {len(params)}")
print("DONE")

I was also wondering if you know of a place where I could find an example of how to use the meta schedule. I couldn’t find anything in the docs, and from what I’ve seen of it, it might be interesting to try as well 🙂
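From skimming recent main, the meta schedule entry points look roughly like this (a sketch only; I haven’t verified it, and the meta_schedule API seems to change between releases):

from tvm import meta_schedule as ms

database = ms.relay_integration.tune_relay(
    mod=mod,
    params=params,
    target=target,
    work_dir="./ms_work_dir",
    max_trials_global=200,
)
lib = ms.relay_integration.compile_relay(
    database=database, mod=mod, target=target, params=params
)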