转载:torch.stack()函数解释

本文介绍了PyTorch中的stack()函数,该函数用于沿着新维度对张量序列进行连接,常用于NLP和CV中,以保留序列和张量矩阵信息。stack()要求输入序列内的张量形状相等,且可以选择任意维度进行拼接。通过示例展示了不同维度拼接的效果,强调了在循环神经网络中stack()的作用,用于处理RNN输出的序列数据。

在pytorch中,常见的拼接函数主要是两个,分别是:

stack()
cat()
实际使用中,这两个函数互相辅助,使用场景不同:关于cat()参考torch.cat(),但是本文主要说stack()。

函数的意义:使用stack可以保留两个信息:[1. 序列] 和 [2. 张量矩阵] 信息,属于【扩张再拼接】的函数。

形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。

该函数常出现在自然语言处理(NLP)和图像卷积神经网络(CV)中。

1. stack()

官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。

浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。

outputs = torch.stack(inputs, dim=?) → Tensor
参数

inputs : 待连接的张量序列。
注:python的序列数据只有list和tuple。

dim : 新的维度, 必须在0到len(outputs)之间。
注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。

2. 重点

函数中的输入inputs只允许是序列;且序列内部的张量元素,必须shape相等
----举例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必须tensor_1.shape == tensor_2.shape

dim是选择生成的维度,必须满足0<=dim<len(outputs);len(outputs)是输出后的tensor的维度大小
不懂的看例子,再回过头看就懂了。

3. 例子

1.准备2个tensor数据,每个的shape都是[3,3]

flag = True
# flag = False

if flag:
    # 假设是时间步T1的输出
    T1 = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
    # 假设是时间步T2的输出
    T2 = torch.tensor([[10, 20, 30],
                       [40, 50, 60],
                       [70, 80, 90]])

    my_stack_00 = torch.stack((T1,T2),dim=0)
    print("\nmy_stack_00:{} shape:{}".format(my_stack_00, my_stack_00.shape))

    my_stack_01 = torch.stack((T1, T2), dim=1)
    print("\nmy_stack_01:{} shape:{}".format(my_stack_01, my_stack_01.shape))

    my_stack_02 = torch.stack((T1, T2), dim=2)
    print("\nmy_stack_02:{} shape:{}".format(my_stack_02, my_stack_02.shape))

    my_stack_03 = torch.stack((T1, T2), dim=3)
    print("\nmy_stack_03:{} shape:{}".format(my_stack_03, my_stack_03.shape))

运行结果:

my_stack_00:tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]]) shape:torch.Size([2, 3, 3])

my_stack_01:tensor([[[ 1,  2,  3],
         [10, 20, 30]],

        [[ 4,  5,  6],
         [40, 50, 60]],

        [[ 7,  8,  9],
         [70, 80, 90]]]) shape:torch.Size([3, 2, 3])

my_stack_02:tensor([[[ 1, 10],
         [ 2, 20],
         [ 3, 30]],

        [[ 4, 40],
         [ 5, 50],
         [ 6, 60]],

        [[ 7, 70],
         [ 8, 80],
         [ 9, 90]]]) shape:torch.Size([3, 3, 2])
Traceback (most recent call last):
  File "C:\Users\DELL\PycharmProjects\hellepytorch\lesson-03.py", line 73, in <module>
dimshape
0[2, 3, 3]
1[3, 2, 3]
2[3, 3, 2]
3溢出报错

4. 总结

函数作用:
函数stack()对序列数据内部的张量进行扩维拼接,指定维度由程序员选择、大小是生成后数据的维度区间。

存在意义:
在自然语言处理和卷及神经网络中, 通常为了保留–[序列(先后)信息] 和 [张量的矩阵信息] 才会使用stack。

函数存在意义?

手写过RNN的同学,知道在循环神经网络中输出数据是:一个list,该列表插入了seq_len个形状是[batch_size, output_size]的tensor,不利于计算,需要使用stack进行拼接,保留–[1.seq_len这个时间步]和–[2.张量属性[batch_size, output_size]]。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值