利用register_forward_hook()精确定位到模型某一层的输入和输出

本文介绍如何在PyTorch中使用hook函数便捷地访问模型隐藏层的输出,通过实例化`TestForHook`模型并针对特定Linear层注册forward_hook,实现在网络运行时获取层的输入和输出。

在论文中偶然读到一些方法会用到模型中间的隐藏层作为分类器,与模型最后一层作为分类器的性能进行对比,故而思考如何能够简便快捷地实现将模型某一层的输出输出拉取出来的方法,发现有现成hook函数可以做到这一点。

hook

hook就是一个钩子,用来把网络中的某一层的输入输出或者其他信息钩出来,如果想知道网络中某一层的详细信息,不用在定义网络时单独写一个print,直接写一个hook函数即可。

register_forward_hook

源代码里说明,hook只能用在forward()函数运行之前,写在forward函数运行之后是没用的,意思是想要运行hook,先把hook的函数写好,然后再实例化网络

def register_forward_hook(self, hook):
        r'''Registers a forward hook on the module.
        The hook will be called every time after :func:`forward` has computed an output.
        It should have the following signature::
            hook(module, input, output) -> None or modified output
        The hook can modify the output. It can modify the input inplace but
        it will not have effect on forward since this is called after
        :func:`forward` is called.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
       '''
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

问题

模型中有时会出现多个Linear层,但net.children()提取出来的所有类型一致的模块其名称也一致,故根据当前Linear层的输入和输出维度进行判断,精确锁定到该层,其他模块也依然适用

代码部分

import torch
import torch.nn as nn
class TestForHook(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1=nn.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值