[Pytorch] register_forward_hook support

I would like to compile a PyTorch model that has forward hooks inserted. Does TVM support something like this? My goal is to use the hooks to return control to the CPU side periodically.

I’m totally new to TVM. Any suggestions would be valuable! Thanks!

No, it is not supported. You have to trace your model via torch.jit.trace(...).
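
For reference, the usual flow looks roughly like this. This is a minimal sketch assuming a torchvision ResNet-18 and an LLVM target; the input name and shape are illustrative, not required values:

import torch
import torchvision
from tvm import relay

model = torchvision.models.resnet18(pretrained=True).eval()
example_input = torch.randn(1, 3, 224, 224)

# Trace the model first; hooks registered in eager mode are not carried along here.
traced = torch.jit.trace(model, example_input)

# Convert the traced module to Relay and compile it.
mod, params = relay.frontend.from_pytorch(traced, [("input0", (1, 3, 224, 224))])
lib = relay.build(mod, target="llvm", params=params)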

Thanks for the answer! Is there any way to add hooks within TVM’s runtime?

Probably not, but what exactly do you want to do with the hooks? TVM and PyTorch work in very different ways. If adding hooks involves Python, it’s not going to work.

We want to implement GPU preemption during inference execution. The way we do it in PyTorch is to use hooks as exit points. I wonder if we can do something like calling “cudaDeviceSynchronize” between CUDA kernels within TVM-compiled models?
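
For context, on the PyTorch side our exit points look roughly like the sketch below. The model is a toy one and should_yield is just a placeholder for the actual scheduling decision:

import torch
import torch.nn as nn

def make_preemption_hook(should_yield):
    def hook(module, inputs, output):
        # Drain queued CUDA kernels so the Python side regains control
        # between layers, then decide whether to yield the GPU.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if should_yield():
            pass  # actual preemption logic (e.g. wait on an event) goes here
    return hook

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
should_yield = lambda: False  # placeholder for the real scheduling decision

# Register on every leaf module so the hook fires between layers.
handles = [m.register_forward_hook(make_preemption_hook(should_yield))
           for m in model.modules() if len(list(m.children())) == 0]
model(torch.randn(1, 8))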

If that’s all you need, it is probably not difficult to support. Actually, I might have other use cases for a runtime hook mechanism. In quantization, I often want to look at the values of intermediate tensors, either for calculating quantization parameters or for figuring out which layers cause the most accuracy loss. If I could pass a user-defined Python function to the runtime and have it called after every layer, that would be very useful for me. Do you know if PyTorch’s register_forward_hook would enable something like that?

PyTorch does support forward hooks with torch.jit.trace(...).

For details, you can check: https://github.com/pytorch/pytorch/issues/34329 and https://github.com/pytorch/pytorch/pull/49544 .

For usage, there is a test file from Pytorch: https://github.com/pytorch/pytorch/blob/5c23888953d277041b341d38dcd5b2d891619ba4/test/jit/test_hooks.py .
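
To give a rough idea, a hook that survives TorchScript compilation looks something like the sketch below. It loosely follows the patterns in that test file (which exercises torch.jit.script); the module and hook names here are made up, and note the explicit type annotations, which TorchScript requires on hooks:

from typing import Tuple
import torch
import torch.nn as nn

class Inner(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x) + 1

def inner_hook(module, inputs: Tuple[torch.Tensor], output: torch.Tensor) -> torch.Tensor:
    # Runs after Inner.forward; a non-None return value replaces the output.
    return output + 100

model = Outer()
model.inner.register_forward_hook(inner_hook)

scripted = torch.jit.script(model)  # the hook is kept in the compiled module
print(scripted(torch.ones(2)))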

I personally think that a hook mechanism would be useful, since it would be convenient to get intermediate outputs for debugging (and for cases like quantization accuracy checking, as you mentioned). PyTorch itself supports this feature; however, it seems we can’t do the same thing in TVM for now. Let me explain a bit:

To actually get the intermediate result, one way is to just “print” the intermediate tensor in the hook. You can use torch.jit.trace to compile a PyTorch model with a print call inside a hook. However, TVM will give you an error saying that some operators are not implemented:

The following operators are not implemented: ['prim::Print']
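
Here is a sketch of the kind of setup that runs into this; the network and input shape are just placeholders:

from typing import Tuple
import torch
import torch.nn as nn
from tvm import relay

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 8, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv2(self.conv1(x))

def print_hook(module, inputs: Tuple[torch.Tensor], output: torch.Tensor):
    print(output.sum())  # the print shows up as a prim::Print node in the hook's graph

net = Net()
net.conv2.register_forward_hook(print_hook)

traced = torch.jit.trace(net, torch.randn(1, 3, 32, 32))
# The conversion below is where the frontend reports prim::Print as unimplemented.
mod, params = relay.frontend.from_pytorch(traced, [("input0", (1, 3, 32, 32))])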

Another way is to create a Python class like this:

from typing import Tuple

import torch

class HookRecorder:
    def __init__(self):
        self.recorder = dict()  # maps layer name -> list of recorded intermediate outputs
        self.handlers = list()

    def _register_hooker(self, name):
        self.recorder[name] = list()
        def named_hooker(module, input: Tuple[torch.Tensor], output: torch.Tensor):
            # Record the output tensor produced by this layer.
            self.recorder[name].append(output)
        return named_hooker

    def register_hookers(self, target_sub_modules, layer_names):
        for module, layer_name in zip(target_sub_modules, layer_names):
            handler = module.register_forward_hook(self._register_hooker(layer_name))
            # Keep each handle so the hook can be removed later.
            self.handlers.append(handler)

    def remove_handlers(self):
        for handler in self.handlers:
            handler.remove()
        self.handlers.clear()

    def __del__(self):
        self.remove_handlers()

hook = HookRecorder()
hook.register_hookers([net.conv2], ["conv2"])
out = net(input)
print(hook.recorder)

In this way, we can indeed collect intermediate values in the Python class. However, this cannot be compiled with torch.jit.trace, since the hooks close over Python-side state (the recorder dict), which cannot be captured in the compiled graph.