Using a tensor as the index of a tensor in PyTorch

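This article explains in detail how a Tensor can be used as an index to access data in PyTorch. It covers indices of type torch.uint8 and torch.long, including one-dimensional and multi-dimensional index tensors, and explains how the different indexing styles affect the result.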

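When working with PyTorch we often see code such as tensor1[tensor2], and different kinds of tensor2 lead to different kinds of indexing. This article explains how indexing with a tensor works.
First, construct tensor1: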

>>> import torch
>>> a = torch.randn(10, 5)
>>> a
tensor([[-0.9434, -2.8668, -0.4331, -1.6842, -0.1823],
        [-1.4545, -1.0065,  0.3228,  0.7457,  0.6225],
        [ 0.5884, -1.4933, -0.4641,  0.4596, -1.5091],
        [-0.4232,  1.0866, -0.3649,  1.4429,  1.7786],
        [-0.3113, -0.3810, -1.0637,  0.5268,  1.0615],
        [ 0.8262, -0.1033,  0.2941, -0.0158,  0.5710],
        [ 0.8346,  0.6172,  0.0416, -0.6910,  1.1025],
        [-1.1312,  0.0694, -0.6494, -2.1948,  2.2036],
        [-0.5208, -0.7442, -0.5526,  1.4329, -0.0613],
        [-0.5576, -0.5130, -0.1988,  0.3616, -0.0838]])
>>> a.shape
torch.Size([10, 5])

Index of type torch.uint8

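Starting with PyTorch 1.2.0 there is also a torch.bool type, which works the same way in this case. In those versions, indexing with a torch.uint8 mask is deprecated and emits a warning recommending torch.bool instead.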
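A minimal sketch, assuming PyTorch 1.2 or later: building a torch.bool mask from the same values used in the uint8 example below (here via `b_uint8 != 0`, since any nonzero entry counts as selected) picks the same rows without relying on the deprecated uint8 path. `a` is a fresh random tensor and the variable names are illustrative.

```python
import torch

a = torch.randn(10, 5)
# uint8 tensor with the same values as the example below
b_uint8 = torch.tensor([2, 9, 0, 1, 2, 3, 4, 0, 5, 0], dtype=torch.uint8)

# Comparison ops return torch.bool; every nonzero entry becomes True
mask = b_uint8 != 0
print(mask.dtype)        # torch.bool

rows = a[mask]
print(rows.shape)        # torch.Size([7, 5])

# The selected rows are exactly the positions where b_uint8 is nonzero
print(torch.equal(rows, a[[0, 1, 3, 4, 5, 6, 8]]))  # True
```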

>>> b = torch.tensor([[1,4,0],[2,0,3]], dtype=torch.uint8)
>>> b
tensor([[1, 4, 0],
        [2, 0, 3]], dtype=torch.uint8)
>>> b.shape
torch.Size([2, 3])
>>> a[b]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: The shape of the mask [2, 3] at index 0 does not match the shape of the indexed tensor [10, 5] at index 0
>>> b = torch.tensor([[1,0],[2,0],[0,2],[0,9],[11,1],[1,0],[2,0],[0,2],[0,9],[11,1]], dtype=torch.uint8)
>>> b.shape
torch.Size([10, 2])
>>> a[b]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: The shape of the mask [10, 2] at index 1 does not match the shape of the indexed tensor [10, 5] at index 1


>>> b = torch.tensor([2,9,0,1,2,3,4,0,5,0], dtype=torch.uint8)
>>> b.shape
torch.Size([10])
>>> a[b]
tensor([[-0.9434, -2.8668, -0.4331, -1.6842, -0.1823],
        [-1.4545, -1.0065,  0.3228,  0.7457,  0.6225],
        [-0.4232,  1.0866, -0.3649,  1.4429,  1.7786],
        [-0.3113, -0.3810, -1.0637,  0.5268,  1.0615],
        [ 0.8262, -0.1033,  0.2941, -0.0158,  0.5710],
        [ 0.8346,  0.6172,  0.0416, -0.6910,  1.1025],
        [-0.5208, -0.7442, -0.5526,  1.4329, -0.0613]])
>>> a[b].shape
torch.Size([7, 5])

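'''
First, when tensor2 has type torch.uint8, the size of every dimension of tensor2
must equal the size of the corresponding leading dimension of tensor1 (tensor2 may have fewer dimensions than tensor1).
Second, in this case tensor2 acts as a mask that picks out values of tensor1 at specific positions,
namely the positions where tensor2 is nonzero. The result has shape
[number of nonzero entries, *remaining dimensions of tensor1], here [7, 5].
'''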
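The two rules above can be checked with a short sketch. It assumes torch.bool masks (PyTorch 1.2+) and uses illustrative names (`x`, `full_mask`, `row_mask`). A mask covering every dimension returns a 1-D tensor of the selected elements in row-major order; a mask covering only the leading dimension keeps the remaining dimensions whole. In both cases the result matches indexing with the coordinates returned by `torch.nonzero`.

```python
import torch

x = torch.arange(12).reshape(3, 4)

# Mask with the same shape as x: result is 1-D, one entry per True
full_mask = x % 3 == 0
picked = x[full_mask]
print(picked)  # tensor([0, 3, 6, 9])

# Mask covering only dim 0: the trailing dim is kept whole, like [:]
row_mask = torch.tensor([True, False, True])
rows = x[row_mask]
print(rows.shape)  # torch.Size([2, 4])

# Equivalent formulation: gather at the coordinates of the True entries
coords = torch.nonzero(full_mask)          # shape [num_true, 2]
print(torch.equal(picked, x[coords[:, 0], coords[:, 1]]))  # True
idx = torch.nonzero(row_mask).squeeze(1)   # shape [num_true]
print(torch.equal(rows, x[idx]))           # True
```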

Index of type torch.long

In this case the elements of tensor2 are treated as positions, and again there are two ways to index:

One-dimensional tensors as indices

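This case works like fancy indexing in NumPy. If fewer one-dimensional index tensors are given than the indexed tensor has dimensions, each unspecified trailing dimension behaves like [:]. Note, however, that the index tensors must have the same shape (more precisely, shapes that can be broadcast together).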
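To make the correspondence with NumPy concrete, here is a sketch that applies the same index lists to a NumPy array and to a tensor sharing its data, then compares the results. It assumes NumPy is installed; `arr`, `rows` and `cols` are illustrative names, and `arange` values replace random ones so the output is predictable.

```python
import numpy as np
import torch

arr = np.arange(15).reshape(5, 3)
t = torch.from_numpy(arr)

rows = [4, 2, 3, 3]
cols = [0, 1, 2, 2]

# Paired fancy indexing: element k is arr[rows[k], cols[k]]
np_result = arr[rows, cols]
torch_result = t[torch.tensor(rows), torch.tensor(cols)]
print(np_result.tolist())     # [12, 7, 11, 11]
print(torch_result.tolist())  # [12, 7, 11, 11]

# A single index array leaves the trailing dimension as [:]
print(np.array_equal(arr[[2, 4, 2]], t[torch.tensor([2, 4, 2])].numpy()))  # True
```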

>>> a = torch.randn(5, 3)
>>> a
tensor([[ 1.1462,  0.4856, -0.0858],
        [-1.2447, -0.9900,  1.9999],
        [-1.5310, -0.3016,  0.7738],
        [-1.3481, -0.3005,  0.8936],
        [ 1.1273,  1.8933, -0.0448]])
>>> a[torch.tensor([4,2,3,3]), torch.tensor([0,1,2,2])]
tensor([ 1.1273, -0.3016,  0.8936,  0.8936])
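'''
This is equivalent to selecting a[4,0], a[2,1], a[3,2], a[3,2]:
the values in each 1-D index tensor give the positions selected along the corresponding dimension
'''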
>>> a[torch.tensor([2,4,2])]
tensor([[-1.5310, -0.3016,  0.7738],
        [ 1.1273,  1.8933, -0.0448],
        [-1.5310, -0.3016,  0.7738]])
'''
This is equivalent to a[torch.tensor([2,4,2]), :], i.e. it selects a[2,:], a[4,:], a[2,:]
'''
>>> a[torch.tensor([4,2,3]), torch.tensor([0,1])]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [2]
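The error message above mentions broadcasting: the index tensors do not need identical shapes, only shapes that broadcast together. A sketch on a 5×3 tensor of the same shape as above (filled with `torch.arange` values instead of random ones, an assumption made for readability):

```python
import torch

a = torch.arange(15).reshape(5, 3)

# Shapes [3] and [1] broadcast to [3]: the single column index is reused
pair = a[torch.tensor([4, 2, 3]), torch.tensor([0])]
print(pair.tolist())  # [12, 6, 9]

# Shapes [3, 1] and [2] broadcast to [3, 2]: every row paired with every column
r = torch.tensor([[4], [2], [3]])
c = torch.tensor([0, 1])
sub = a[r, c]
print(sub.shape)      # torch.Size([3, 2])
print(sub.tolist())   # [[12, 13], [6, 7], [9, 10]]

# Shapes [3] and [2] cannot broadcast, which reproduces the error above
raised = False
try:
    a[torch.tensor([4, 2, 3]), torch.tensor([0, 1])]
except IndexError as err:
    raised = True
    print("IndexError:", err)
```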

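Multi-dimensional tensors as indices

Here every element of tensor2 selects a position along dimension 0 of tensor1. Taking a two-dimensional tensor2 (and a two-dimensional tensor1) as an example:
tensor1[tensor2].shape = torch.Size([tensor2.shape[0], tensor2.shape[1], tensor1.shape[1]])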

>>> b = torch.tensor([[2,4,3], [0,1,2],[3,0,1]])
>>> b.shape
torch.Size([3, 3])
>>> a[b]
tensor([[[-1.5310, -0.3016,  0.7738],
         [ 1.1273,  1.8933, -0.0448],
         [-1.3481, -0.3005,  0.8936]],

        [[ 1.1462,  0.4856, -0.0858],
         [-1.2447, -0.9900,  1.9999],
         [-1.5310, -0.3016,  0.7738]],

        [[-1.3481, -0.3005,  0.8936],
         [ 1.1462,  0.4856, -0.0858],
         [-1.2447, -0.9900,  1.9999]]])
>>> a[b].shape
torch.Size([3, 3, 3])
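The shape rule can be written more generally as `b.shape + a.shape[1:]`. A sketch that checks it on the example above and shows an equivalent formulation: `index_select` along dimension 0 with the flattened index, followed by a reshape. The data is random and the names `out` and `alt` are illustrative.

```python
import torch

a = torch.randn(5, 3)
b = torch.tensor([[2, 4, 3], [0, 1, 2], [3, 0, 1]])

out = a[b]
# Result shape: the index shape followed by the non-indexed dims of a
print(out.shape == b.shape + a.shape[1:])  # True

# Same result: pick rows along dim 0 with the flattened index, then reshape
alt = a.index_select(0, b.flatten()).reshape(*b.shape, *a.shape[1:])
print(torch.equal(out, alt))  # True

# out[i, j] is simply row b[i, j] of a
print(torch.equal(out[1, 2], a[b[1, 2]]))  # True
```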

When tensor2 has more dimensions (here three):

>>> b = torch.tensor([[[2,4], [0,1],[3,0]],[[2,4], [0,1],[3,0]]])
>>> b.shape
torch.Size([2, 3, 2])
>>> a[b].shape
torch.Size([2, 3, 2, 3])

As you can see, tensor2 may have more dimensions than tensor1.

When tensor1 has more dimensions:

>>> a = torch.randn(5,3,3)
>>> b = torch.tensor([[[2,4], [0,1],[3,0],[2,1]]])
>>> b.shape
torch.Size([1, 4, 2])
>>> a[b].shape
torch.Size([1, 4, 2, 3, 3])

The same rule still applies: the shape of the result is the shape of tensor2 followed by the shape of tensor1[0].
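A compact way to state this rule is that a long index tensor `b` used on its own behaves like `a[b, ...]`, with result shape `b.shape + a[0].shape`. A sketch that checks both facts on the shapes used in this section, with random data; the helper name `check` is hypothetical, introduced only for illustration.

```python
import torch

def check(a_shape, b):
    # Hypothetical helper: index a random tensor of shape a_shape with b
    a = torch.randn(*a_shape)
    out = a[b]
    # Rule from the text: index shape followed by the shape of a[0]
    assert out.shape == b.shape + a[0].shape
    # Indexing with b alone is the same as a[b, ...]
    assert torch.equal(out, a[b, ...])
    return out.shape

b3 = torch.tensor([[[2, 4], [0, 1], [3, 0]], [[2, 4], [0, 1], [3, 0]]])
print(check((5, 3), b3))     # torch.Size([2, 3, 2, 3])

b4 = torch.tensor([[[2, 4], [0, 1], [3, 0], [2, 1]]])
print(check((5, 3, 3), b4))  # torch.Size([1, 4, 2, 3, 3])
```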
