(即插即用模块-特征处理部分) 二十八、(2024 AAAI) GAU 门控注意力单元

Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

本镜像基于 RTX 4090D 24GB 显存 + CUDA 12.4 + 驱动 550.90.07 深度优化,内置完整运行环境与 Qwen3-32B 模型依赖,开箱即用。

在这里插入图片描述

paper:Gated Attention Coding for Training High-Performance and Efcient Spiking Neural Networks

Code:https://github.com/bollossom/GAC


1、Gated Attention Unit

现阶段直接编码存在一些局限性:即周期性输出: 直接编码在每个时间步重复生成相同的浮点数,导致脉冲表示缺乏时间动态,无法有效模拟人类视觉系统对动态视觉信息的处理。无力脉冲表示: 由于缺乏时间动态,直接编码生成的脉冲表示信息量有限,导致 SNN 参数敏感性较高,性能受限。效率低下: 直接编码需要大量时间步才能维持高性能,导致模拟时间和能耗增加。为此,这篇论文提出一种 门控注意力单元(Gated Attention Unit)。GAU 通过利用多维度注意力机制进行门控,将静态数据高效地编码为具有时间动态的强大表示。

对于输入X,GAU的实现过程:

  1. 提取图像特征: 输入特征输入 GAU 模块后,首先通过一个卷积层提取图像特征。
  2. 时间注意力: 通过 Avgpool 和 Maxpool 计算输入特征图的最大值和平均值,并使用 MLP 网络将其转换为时间权重向量。
  3. 空间通道注意力: 使用共享的 2D 卷积操作在每个时间步获取空间通道矩阵。
  4. 门控: 将时间权重向量广播到特征图大小,并与空间通道矩阵进行 Hadamard 乘积,生成最终编码结果。

GAU 的优势:

  • 生成具有时间动态的编码结果,提升 SNN 的表现能力。
  • 作为预处理层,不破坏 SNN 的脉冲驱动特性,易于类脑硬件实现。
  • 与其他 SNN 注意力方法相比,参数量更少,计算效率更高。

Gated Attention Unit 结构图:
在这里插入图片描述


2、代码实现

import torch
import torch.nn as nn


class TA(nn.Module):
    def __init__(self,  T,ratio=2):

        super(TA, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        self.sharedMLP = nn.Sequential(
            nn.Conv3d(T, T // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv3d(T // ratio, T, 1,bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg = self.avg_pool(x)
        # B,T,C
        out1 = self.sharedMLP(avg)
        max = self.max_pool(x)
        # B,T,C
        out2 = self.sharedMLP(max)
        out = out1+out2

        return out

# task  generation
class SCA(nn.Module):
    def __init__(self, in_planes, kerenel_size=3,ratio = 1):
        super(SCA, self).__init__()
        self.sharedMLP = nn.Sequential(
                nn.Conv2d(in_planes, in_planes // ratio, kerenel_size, padding='same', bias=False),
                nn.ReLU(),
                nn.Conv2d(in_planes // ratio, in_planes, kerenel_size, padding='same', bias=False),)
    def forward(self, x):
        b,t, c, h, w = x.shape
        x = x.flatten(0,1)
        x = self.sharedMLP(x)
        out = x.reshape(b,t, c, h, w)
        return out


if __name__ == '__main__':
    x = torch.randn(4, 10, 64, 128, 128).cuda()
    ta = TA(10).cuda()
    sca = SCA(64).cuda()
    out_ta = ta(x)
    out_sca = sca(x)
    print(out_ta.shape)
    print(out_sca.shape)

您可能感兴趣的与本文相关的镜像

Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

Qwen
文本生成
Qwen3

本镜像基于 RTX 4090D 24GB 显存 + CUDA 12.4 + 驱动 550.90.07 深度优化,内置完整运行环境与 Qwen3-32B 模型依赖,开箱即用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值