Hi @apeskov!
Yes, sure! I’ll clean up my code so I can extract the patch out of it, but the change was rather simple:

@contextlib.contextmanager
def default_module_loader_mgr(remote_kwargs, build_result):
    remote = request_remote(**remote_kwargs)
    # if pre_load_function is not None:
    #     pre_load_function(remote, build_result)
    print("Remote upload = ", build_result.filename)
    remote.upload(build_result.filename)
    try:
        yield remote, remote.load_module(os.path.split(build_result.filename)[1])
    finally:
        # clean up remote files
        remote.remove(build_result.filename)
        remote.remove(os.path.splitext(build_result.filename)[0] + ".so")
        remote.remove("")
        print("Clean up = ", build_result.filename)


def default_module_loader(pre_load_function=None):
    """Returns a default function that can be passed as module_loader to run_through_rpc.

    Parameters
    ----------
    pre_load_function : Optional[Function[tvm.rpc.Session, tvm.runtime.Module]]
        Invoked after a session is established and before the default code-loading RPC calls are
        issued. Allows performing pre-upload actions, e.g. resetting the remote runtime environment.

    Returns
    -------
    ModuleLoader :
        A function that can be passed as module_loader to run_through_rpc.
    """
    return default_module_loader_mgr

I’m not good with Python (yet), but what I understood is that Python has problems pickling nested functions: pickle stores a function as a reference to its importable name, and a function defined inside another function has no such name. So I just moved default_module_loader_mgr out of the default_module_loader function scope, and that did the trick.
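
For anyone who wants to see the pickling behaviour in isolation, here is a minimal sketch that does not use TVM at all. The names `make_nested_loader`, `nested_loader`, `top_level_loader` and `can_pickle` are hypothetical stand-ins I made up for illustration; they are not part of the TVM code.

```python
import pickle


def make_nested_loader():
    # Hypothetical stand-in for a loader defined inside another function.
    def nested_loader(x):
        return x + 1

    return nested_loader


def top_level_loader(x):
    # Hypothetical stand-in for a loader moved out to module scope.
    return x + 1


def can_pickle(obj):
    """Return True if pickle.dumps succeeds on obj, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        # Depending on the Python version, a local function fails with
        # either of these exception types.
        return False


nested_ok = can_pickle(make_nested_loader())
top_level_ok = can_pickle(top_level_loader)
print("nested function picklable:", nested_ok)
print("module-level function picklable:", top_level_ok)
```

When run as a script, the nested function should be reported as not picklable and the module-level one as picklable, which matches what I saw after moving default_module_loader_mgr. Note that with this workaround pre_load_function is no longer visible inside the loader, which is why those two lines are commented out in my version.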