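When using PyTorch we often see code like tensor1[tensor2], and different kinds of tensor2 lead to different indexing behavior. This article explains in detail how indexing with a tensor works.
First, construct tensor1: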
>>> a = torch.randn(10, 5)
>>> a
tensor([[-0.9434, -2.8668, -0.4331, -1.6842, -0.1823],
[-1.4545, -1.0065, 0.3228, 0.7457, 0.6225],
[ 0.5884, -1.4933, -0.4641, 0.4596, -1.5091],
[-0.4232, 1.0866, -0.3649, 1.4429, 1.7786],
[-0.3113, -0.3810, -1.0637, 0.5268, 1.0615],
[ 0.8262, -0.1033, 0.2941, -0.0158, 0.5710],
[ 0.8346, 0.6172, 0.0416, -0.6910, 1.1025],
[-1.1312, 0.0694, -0.6494, -2.1948, 2.2036],
[-0.5208, -0.7442, -0.5526, 1.4329, -0.0613],
[-0.5576, -0.5130, -0.1988, 0.3616, -0.0838]])
>>> a.shape
torch.Size([10, 5])
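Index of dtype torch.uint8
Starting with PyTorch 1.2.0 there is also a torch.bool dtype, which works the same way here. From that version on, indexing with a torch.uint8 mask is deprecated and emits a warning, so torch.bool is the preferred mask type.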
>>> b = torch.tensor([[1,4,0],[2,0,3]], dtype=torch.uint8)
>>> b
tensor([[1, 4, 0],
[2, 0, 3]], dtype=torch.uint8)
>>> b.shape
torch.Size([2, 3])
>>> a[b]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
IndexError: The shape of the mask [2, 3] at index 0 does not match the shape of the indexed tensor [10, 5] at index 0
>>> b = torch.tensor([[1,0],[2,0],[0,2],[0,9],[11,1],[1,0],[2,0],[0,2],[0,9],[11,1]], dtype=torch.uint8)
>>> b.shape
torch.Size([10, 2])
>>> a[b]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
IndexError: The shape of the mask [10, 2] at index 1 does not match the shape of the indexed tensor [10, 5] at index 1
>>> b = torch.tensor([2,9,0,1,2,3,4,0,5,0], dtype=torch.uint8)
>>> b.shape
torch.Size([10])
>>> a[b]
tensor([[-0.9434, -2.8668, -0.4331, -1.6842, -0.1823],
[-1.4545, -1.0065, 0.3228, 0.7457, 0.6225],
[-0.4232, 1.0866, -0.3649, 1.4429, 1.7786],
[-0.3113, -0.3810, -1.0637, 0.5268, 1.0615],
[ 0.8262, -0.1033, 0.2941, -0.0158, 0.5710],
[ 0.8346, 0.6172, 0.0416, -0.6910, 1.1025],
[-0.5208, -0.7442, -0.5526, 1.4329, -0.0613]])
>>> a[b].shape
torch.Size([7, 5])
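'''
First, when tensor2 has dtype torch.uint8, each dimension of tensor2 must have the same size as the corresponding dimension of tensor1, starting from dimension 0 (tensor2 may have fewer dimensions than tensor1).
Second, in this case tensor2 acts as a mask that picks out the values of tensor1 at specific positions, namely the positions where tensor2 is nonzero. The selected entries are stacked along a single dimension 0, so the result has shape [number of nonzero entries in tensor2] followed by the dimensions of tensor1 that tensor2 does not cover (here [7, 5]).
'''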
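A minimal sketch of the same mask semantics using a torch.bool mask (the dtype preferred since 1.2.0). The tensor values are random and the variable names are illustrative; only the shapes and the selection pattern matter. It also checks that a 1-D mask selects the same rows as the long indices of its nonzero positions.

```python
import torch

# A 10x5 tensor like `a` above (random values; only the shapes matter).
a = torch.randn(10, 5)

# A uint8 "mask" with the same values as in the example above.
u8 = torch.tensor([2, 9, 0, 1, 2, 3, 4, 0, 5, 0], dtype=torch.uint8)

# Convert it to a bool mask: True wherever the uint8 value is nonzero.
row_mask = u8 != 0
rows = a[row_mask]          # whole rows where the mask is True
print(rows.shape)           # torch.Size([7, 5])

# A 1-D mask selects the same rows as the long indices of its nonzero positions.
idx = row_mask.nonzero(as_tuple=True)[0]
print(idx)                  # tensor([0, 1, 3, 4, 5, 6, 8])
assert torch.equal(a[idx], rows)

# A mask with the same shape as `a` selects single elements and
# returns them in row-major order as a flat 1-D tensor.
elem_mask = a > 0
pos = a[elem_mask]
assert pos.shape == (int(elem_mask.sum()),)
```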
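Index of dtype torch.long
Here the elements of tensor2 are treated as positions, and again there are two ways the indexing can work:
One-dimensional tensors as indices
This works the same way as fancy indexing in NumPy: each 1-D index tensor gives the positions to select in one dimension, and the k-th element of the result is taken at the k-th entry of every index tensor. If there are fewer 1-D index tensors than tensor1 has dimensions, the remaining dimensions behave like [:]. Note, however, that the index tensors must have matching shapes (strictly, shapes that can be broadcast together, which is what the error message below refers to).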
>>> a = torch.randn(5, 3)
>>> a
tensor([[ 1.1462, 0.4856, -0.0858],
[-1.2447, -0.9900, 1.9999],
[-1.5310, -0.3016, 0.7738],
[-1.3481, -0.3005, 0.8936],
[ 1.1273, 1.8933, -0.0448]])
>>> a[torch.tensor([4,2,3,3]), torch.tensor([0,1,2,2])]
tensor([ 1.1273, -0.3016, 0.8936, 0.8936])
'''
This is equivalent to selecting a[4,0], a[2,1], a[3,2], a[3,2].
The values in each 1-D tensor are the positions selected in the corresponding dimension.
'''
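To make the pairing explicit, here is a small check (random data, illustrative names) that a[rows, cols] gives the same result as looping over the index pairs and picking a[r, c] one at a time.

```python
import torch

a = torch.randn(5, 3)
rows = torch.tensor([4, 2, 3, 3])
cols = torch.tensor([0, 1, 2, 2])

# Fancy indexing: the k-th output element is a[rows[k], cols[k]].
picked = a[rows, cols]

# The same selection written as an explicit loop over index pairs.
looped = torch.stack([a[r, c] for r, c in zip(rows.tolist(), cols.tolist())])
assert torch.equal(picked, looped)
print(picked.shape)  # torch.Size([4])
```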
>>> a[torch.tensor([2,4,2])]
tensor([[-1.5310, -0.3016, 0.7738],
[ 1.1273, 1.8933, -0.0448],
[-1.5310, -0.3016, 0.7738]])
'''
This is equivalent to a[torch.tensor([2,4,2]), :], which selects a[2,:], a[4,:], a[2,:].
'''
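The following sketch (random data) checks the rule that missing dimensions behave like [:], and that in this case torch.index_select along dimension 0 gathers the same rows.

```python
import torch

a = torch.randn(5, 3)
idx = torch.tensor([2, 4, 2])

# With one index tensor, the trailing dimension is taken whole, as if written `:`.
r1 = a[idx]
r2 = a[idx, :]
# torch.index_select performs the same row gather along dim 0.
r3 = torch.index_select(a, 0, idx)

assert torch.equal(r1, r2) and torch.equal(r1, r3)
print(r1.shape)  # torch.Size([3, 3])
```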
>>> a[torch.tensor([4,2,3]), torch.tensor([0,1])]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [2]
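The error message says the index tensors "could not be broadcast together", which suggests the real requirement is broadcast-compatible shapes rather than strictly identical ones. A short sketch with random data: a length-1 index broadcasts against a length-3 one, and an index of shape [3, 1] combined with one of shape [2] yields a [3, 2] grid pairing every chosen row with every chosen column.

```python
import torch

a = torch.randn(5, 3)
rows = torch.tensor([4, 2, 3])

# A length-1 index tensor broadcasts against a length-3 one: column 0 of rows 4, 2, 3.
col0 = a[rows, torch.tensor([0])]
print(col0.shape)  # torch.Size([3])

# Shapes [3, 1] and [2] broadcast to [3, 2]: every chosen row paired with every chosen column.
grid = a[rows.unsqueeze(1), torch.tensor([0, 1])]
print(grid.shape)  # torch.Size([3, 2])
assert torch.equal(grid, a[rows][:, [0, 1]])
```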
Multi-dimensional tensors as indices
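Here every element of tensor2 selects a position along dimension 0 of tensor1, i.e. a whole slice tensor1[i]. Taking a 2-D tensor2 and a 2-D tensor1 as an example:
tensor1[tensor2].shape = torch.Size([tensor2.shape[0], tensor2.shape[1], tensor1.shape[1]])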
>>> b = torch.tensor([[2,4,3], [0,1,2],[3,0,1]])
>>> b.shape
torch.Size([3, 3])
>>> a[b]
tensor([[[-1.5310, -0.3016, 0.7738],
[ 1.1273, 1.8933, -0.0448],
[-1.3481, -0.3005, 0.8936]],
[[ 1.1462, 0.4856, -0.0858],
[-1.2447, -0.9900, 1.9999],
[-1.5310, -0.3016, 0.7738]],
[[-1.3481, -0.3005, 0.8936],
[ 1.1462, 0.4856, -0.0858],
[-1.2447, -0.9900, 1.9999]]])
>>> a[b].shape
torch.Size([3, 3, 3])
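Each element of b is replaced by a whole row of a. One way to see this (a sketch with random data, not a description of how PyTorch implements it internally) is that a[b] equals gathering rows with the flattened index and reshaping back to b.shape + a.shape[1:].

```python
import torch

a = torch.randn(5, 3)
b = torch.tensor([[2, 4, 3], [0, 1, 2], [3, 0, 1]])

out = a[b]
# Each entry of b picks a whole row of a, so the shape is b.shape + a.shape[1:].
assert out.shape == b.shape + a.shape[1:]

# Equivalent formulation: gather rows with the flattened index, then restore b's shape.
alt = a[b.reshape(-1)].reshape(*b.shape, *a.shape[1:])
assert torch.equal(out, alt)
print(out.shape)  # torch.Size([3, 3, 3])
```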
When tensor2 has even more dimensions (here three):
>>> b = torch.tensor([[[2,4], [0,1],[3,0]],[[2,4], [0,1],[3,0]]])
>>> b.shape
torch.Size([2, 3, 2])
>>> a[b].shape
torch.Size([2, 3, 2, 3])
As you can see, tensor2 can have more dimensions than tensor1.
When tensor1 itself has more than two dimensions:
>>> a = torch.randn(5,3,3)
>>> b = torch.tensor([[[2,4], [0,1],[3,0],[2,1]]])
>>> b.shape
torch.Size([1, 4, 2])
>>> a[b].shape
torch.Size([1, 4, 2, 3, 3])
The same rule still applies: the shape of the result is the shape of tensor2 followed by the shape of tensor1[0], i.e. tensor2.shape + tensor1.shape[1:].
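The rule can be spot-checked for the three cases above with a small helper. check_shape_rule is a hypothetical name used only for this illustration; it builds random data and random valid indices along dimension 0.

```python
import torch

def check_shape_rule(t1_shape, idx_shape):
    """Hypothetical helper: check t1[idx].shape == idx.shape + t1.shape[1:] for a long index."""
    t1 = torch.randn(*t1_shape)
    # Random valid positions along dim 0 of t1 (torch.randint returns int64 by default).
    idx = torch.randint(0, t1_shape[0], idx_shape)
    return t1[idx].shape == idx.shape + t1.shape[1:]

# The cases from the text above.
print(check_shape_rule((5, 3), (3, 3)))        # True
print(check_shape_rule((5, 3), (2, 3, 2)))     # True
print(check_shape_rule((5, 3, 3), (1, 4, 2)))  # True
```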
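This article explained how a Tensor can be used as an index in PyTorch. It covered index tensors of dtype torch.uint8 and torch.long, including 1-D and multi-dimensional index tensors, and showed how each indexing mode affects the result.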
