GraphModule failed to load params

Hi there, I was trying to deploy a model with TVM using slightly modified weights. However, the GraphModule always yields the same output, even after the params are updated.

g = graph_executor.GraphModule(lib["default"](dev))
g.set_input(input_name, data)
g.run()

output = g.get_output(0)
print(np.mean(output.numpy()) * 10000000 )
#### 94.54726750846021

# try to change the weights
tmp = params["fc_weight"].numpy()
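params["fc_weight"] = tvm.nd.array(np.random.randn(*tmp.shape).astype("float32"), tvm.cpu(0))
# overwrite with all zeros of the same shape and dtype (np.zeros_like(tmp.shape) would only give a length-2 vector)
params["fc_weight"] = tvm.nd.array(np.zeros_like(tmp), tvm.cpu(0))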

g.load_params(tvm.runtime.save_param_dict(params))
g.set_input(input_name, data)
g.run()
output = g.get_output(0)
print(np.mean(output.numpy()) * 10000000 )
#### 94.54726750846021

The output stays the same after the new parameters are loaded. I have confirmed that the issue appears on both x86 and ARM. The TVM I am using is based on commit 6720d3593d4dac6015418d4b7e9ad875bbf0b0a2 (submitted on Jan 24, 2022).

The complete code to reproduce the issue is below:

import numpy as np
import torch
from torchvision import models

import tvm
from tvm import relay
from tvm.contrib import graph_executor

target = "llvm"

input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
input_name = "input0"
shape_list = [(input_name, input_data.shape)]

model = models.resnet18(pretrained=True)
model = model.eval()

data = np.random.randn(*input_shape)
dev = tvm.cpu(0)

scripted_model = torch.jit.trace(model, input_data).eval()

mod, params = relay.frontend.from_pytorch(scripted_model, shape_list, use_parser_friendly_name=True)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

g = graph_executor.GraphModule(lib["default"](dev))

with open("tmp.params", "wb") as fp:
    bparams = tvm.runtime.save_param_dict(params)
    fp.write(bparams)

# execute
g.set_input(input_name, data)
g.run()

output = g.get_output(0)
print(np.mean(output.numpy()) * 10000000 )
print(np.mean(params["fc_weight"].numpy()))

tmp = params["fc_weight"].numpy()
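params["fc_weight"] = tvm.nd.array(np.random.randn(*tmp.shape).astype("float32"), tvm.cpu(0))
# overwrite with all zeros of the same shape and dtype (np.zeros_like(tmp.shape) would only give a length-2 vector)
params["fc_weight"] = tvm.nd.array(np.zeros_like(tmp), tvm.cpu(0))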

print(np.mean(params["fc_weight"].numpy()))

g.load_params(tvm.runtime.save_param_dict(params))
g.set_input(input_name, data)
g.run()
output = g.get_output(0)
print(np.mean(output.numpy()) * 10000000 )

This is probably because we embed (or "bind", as it is called in the source code) all constant params into the mod itself. You can check this by unpacking the build result, graph, libs, params = relay.build(...), and checking whether params is empty.

Hi @masahi, when I called

with tvm.transform.PassContext(opt_level=3):
    graph, libs, params = relay.build(mod, target=target, params=params)
type(graph), type(libs), type(params)
# (str, tvm.runtime.module.Module, dict)

with tvm.transform.PassContext(opt_level=3):
    graph, libs, params = relay.build(mod, target=target, params=None)
type(graph), type(libs), type(params)
# (str, tvm.runtime.module.Module, dict)

Whether or not params is passed in, the returned params is not None.

Yeah, can you check if params is an empty dict?

Yes, in the latter example (where the input params is set to None), the returned params is an empty dict.

Yes, this means the params are hard-coded into the compiled lib. If you want to change them at runtime, you need to make them additional inputs: remove the params you want to change from the dict passed to relay.build(...) and supply them with set_input instead. Note that this prevents compile-time computation on those weights, so performance might get worse.
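A minimal sketch of the "remove the param from relay.build(...)" step, in plain Python. split_params is a hypothetical helper, not a TVM API; it only separates the weights that stay embedded from the weights that will be fed at runtime. The TVM calls it is meant to be combined with are described after the block.

```python
def split_params(params, runtime_names):
    """Split a params dict into (bound, runtime).

    `bound` is what you would pass to relay.build(..., params=bound), so those
    weights are embedded as constants. `runtime` holds the weights left out of
    the build, which must then be supplied with set_input before run().
    """
    runtime_names = set(runtime_names)
    missing = runtime_names - set(params)
    if missing:
        # Fail loudly instead of silently building with everything bound.
        raise KeyError(f"not in params: {sorted(missing)}")
    bound = {k: v for k, v in params.items() if k not in runtime_names}
    runtime = {k: v for k, v in params.items() if k in runtime_names}
    return bound, runtime
```

In the reproduction script this would be used roughly as bound, runtime = split_params(params, ["fc_weight"]), then lib = relay.build(mod, target=target, params=bound), and, once g is created, g.set_input("fc_weight", new_weight) with an array of the same shape and dtype before each g.run(). Because fc_weight is no longer bound, it remains a free input of the graph, which is what lets set_input reach it.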

I seem to have found a solution for this.

load_params appeared not to work because the params stored in the compiled module are keyed p0, p1, p2 … rather than by their original names, so the fc_weight entry I loaded never matched any of them.

In my previous example, if I change the key instead:

v = lib_params["p60"].numpy()
lib_params["p60"] = tvm.nd.array(np.zeros_like(v), tvm.cpu(0))

g.load_params(tvm.runtime.save_param_dict(lib_params))
g.run()
g.get_output(0).numpy()
# result all zero.
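To find which pN key corresponds to an original weight such as fc_weight, one heuristic is to compare values: if a bound weight was lifted out unchanged, its pN entry holds exactly the same array. Below is a minimal sketch of that idea; match_lifted_params is a hypothetical helper, not part of TVM, and it assumes lifted_params is the third value of graph, libs, lib_params = relay.build(...). Weights that were rewritten at compile time (for example folded into another constant or given a new layout) will not match, and the helper reports an empty list for them.

```python
import numpy as np


def _to_numpy(value):
    # TVM NDArrays (and torch tensors) expose .numpy(); NumPy arrays are used as-is.
    return value.numpy() if hasattr(value, "numpy") else np.asarray(value)


def match_lifted_params(orig_params, lifted_params):
    """Map each original param name to the lifted keys (p0, p1, ...) whose
    values are identical in shape, dtype and contents."""
    lifted = {key: _to_numpy(val) for key, val in lifted_params.items()}
    mapping = {}
    for name, value in orig_params.items():
        arr = _to_numpy(value)
        mapping[name] = [
            key
            for key, cand in lifted.items()
            if cand.shape == arr.shape
            and cand.dtype == arr.dtype
            and np.array_equal(cand, arr)
        ]
    return mapping
```

An empty list means the original weight has no unchanged copy in the compiled module, so it cannot be swapped through load_params and would need the runtime-input approach described earlier in the thread (or a rebuild).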