在论文中偶然读到一些方法会用到模型中间的隐藏层作为分类器,与模型最后一层作为分类器的性能进行对比,故而思考如何能够简便快捷地实现将模型某一层的输出输出拉取出来的方法,发现有现成hook函数可以做到这一点。
hook
hook就是一个钩子,用来把网络中的某一层的输入输出或者其他信息钩出来,如果想知道网络中某一层的详细信息,不用在定义网络时单独写一个print,直接写一个hook函数即可。
register_forward_hook
源代码里说明,hook只能用在forward()函数运行之前,写在forward函数运行之后是没用的,意思是想要运行hook,先把hook的函数写好,然后再实例化网络
def register_forward_hook(self, hook):
r'''Registers a forward hook on the module.
The hook will be called every time after :func:`forward` has computed an output.
It should have the following signature::
hook(module, input, output) -> None or modified output
The hook can modify the output. It can modify the input inplace but
it will not have effect on forward since this is called after
:func:`forward` is called.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
'''
handle = hooks.RemovableHandle(self._forward_hooks)
self._forward_hooks[handle.id] = hook
return handle
问题
模型中有时会出现多个Linear层,但net.children()提取出来的所有类型一致的模块其名称也一致,故根据当前Linear层的输入和输出维度进行判断,精确锁定到该层,其他模块也依然适用
代码部分
import torch
import torch.nn as nn
class TestForHook(nn.Module):
def __init__(self):
super().__init__()
self.linear_1=nn.

本文介绍如何在PyTorch中使用hook函数便捷地访问模型隐藏层的输出,通过实例化`TestForHook`模型并针对特定Linear层注册forward_hook,实现在网络运行时获取层的输入和输出。

1008

被折叠的 条评论
为什么被折叠?



