Preface

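Masking is a common operation in deep learning. While reading PyTorch implementations of the Transformer recently, I kept running into various mask-related calls, so here is a summary.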

1.Tensor.masked_fill_(mask, value)

Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor.

Parameters
mask (BoolTensor) – the boolean mask
value (float) – the value to fill in with
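A minimal sketch of two points from the description above, using arbitrary values. First, the trailing underscore in masked_fill_ marks the in-place version, while masked_fill returns a new tensor and leaves the original alone. Second, the mask only has to be broadcastable, so a 1-D mask of shape (3,) is applied to every row of a 2x3 tensor. The names a, b and col_mask are illustrative.

```python
import torch

a = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
col_mask = torch.tensor([True, False, True])  # shape (3,), broadcast over both rows

# Out-of-place: returns a new tensor and leaves `a` untouched
b = a.masked_fill(col_mask, -1.0)
print(a)  # still [[0, 1, 2], [3, 4, 5]]
print(b)  # [[-1, 1, -1], [-1, 4, -1]]

# In-place: the trailing underscore means `a` itself is overwritten
a.masked_fill_(col_mask, -1.0)
print(torch.equal(a, b))  # True
```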

Example

import torch
mask = torch.tensor([[1, 0, 0], [0, 1, 0],  [0, 0, 1]]).bool()
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
a = torch.randn(3,3)
a.masked_fill(mask, 0)
# tensor([[ 0.0000,  0.6781,  0.6532],
#         [-1.2078,  0.0000,  0.4964],
#         [ 0.2192, -0.6276,  0.0000]])
a.masked_fill(~mask, 0)  # ~ inverts the mask
# tensor([[-0.4438,  0.0000,  0.0000],
#         [ 0.0000,  1.3907,  0.0000],
#         [ 0.0000,  0.0000,  2.2462]])
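Since the motivation here is Transformer code, this is the pattern where masked_fill usually shows up: positions a query must not attend to are filled with -inf before the softmax, so their attention weight becomes exactly 0. A minimal sketch, assuming a single attention head, random scores, and a causal (look-ahead) mask built with torch.triu; the names scores, causal_mask and weights are illustrative, not taken from any particular implementation.

```python
import torch

torch.manual_seed(0)
seq_len = 4

# Hypothetical raw attention scores for one head, shape (query, key)
scores = torch.randn(seq_len, seq_len)

# True above the diagonal marks "future" keys that each query must not see
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# -inf becomes exp(-inf) = 0 inside softmax, so masked keys get zero weight
weights = torch.softmax(scores.masked_fill(causal_mask, float("-inf")), dim=-1)
print(weights)  # lower-triangular; every row sums to 1
```

Each row keeps at least its diagonal entry unmasked. If an entire row were filled with -inf, the softmax for that row would come out as NaN, which is why some code fills with a large negative number instead.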

2.torch.masked_select(input, mask, *, out=None) → Tensor

Returns a new 1-D tensor which indexes the input tensor according to the boolean mask mask which is a BoolTensor.
The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.

(Note) The returned tensor does not use the same storage as the original tensor.
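A quick sketch of what the separate storage means in practice, with small hand-picked values: the result of masked_select is a copy, so writing into it leaves the input unchanged.

```python
import torch

x = torch.tensor([[0.1, 0.9], [0.7, 0.2]])
mask = x > 0.5
selected = torch.masked_select(x, mask)  # tensor([0.9000, 0.7000]), row-major order

# Zeroing the result in place does not touch x, because the memory is not shared
selected.zero_()
print(x)  # still contains 0.9 and 0.7
```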

Parameters
input (Tensor) – the input tensor.
mask (BoolTensor) – the tensor containing the binary mask to index with

Example

import torch
x = torch.randn(3,4)
# tensor([[ 0.2914, -0.1056,  0.4946,  0.2926],
#         [-1.0920, -0.2156,  3.0989, -0.9067],
#         [-0.1522,  1.9527,  0.1660,  0.8310]])
mask = x > 0.5
# tensor([[False, False, False, False],
#         [False, False,  True, False],
#         [False,  True, False,  True]])
torch.masked_select(x, mask)
# tensor([3.0989, 1.9527, 0.8310])
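Two related points, sketched with random data: for a mask of the same shape, boolean indexing x[mask] gives the same 1-D result as torch.masked_select; and because the mask only needs to be broadcastable, a mask of shape (4,) selects the same columns from every row. The name col_mask is illustrative.

```python
import torch

x = torch.randn(3, 4)
mask = x > 0.5

# Same-shaped mask: masked_select and boolean indexing agree
assert torch.equal(torch.masked_select(x, mask), x[mask])

# Broadcast mask of shape (4,): picks columns 0 and 2 from each row,
# returned as one flat tensor in row-major order
col_mask = torch.tensor([True, False, True, False])
picked = torch.masked_select(x, col_mask)
print(picked.shape)  # torch.Size([6])
```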

3.Tensor.masked_scatter_(mask, source)

Copies elements from source into self tensor at positions where the mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor. The source should have at least as many elements as the number of ones in mask.

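mask must be broadcastable to the shape of the tensor, and source must contain at least as many elements as mask has True entries; the simplest case is a source with the same shape as the tensor.
What it does: it visits the positions where mask is True in row-major order and fills them with consecutive elements of source, read from the start of source. Elements are not taken from the matching positions in source.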
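To see how elements of source are consumed, here is a small case with a non-constant source (these are the same values as the example in the PyTorch documentation for masked_scatter_). The mask is applied to the target tensor; source is simply read from its first element onward.

```python
import torch

target = torch.zeros(2, 5, dtype=torch.long)
mask = torch.tensor([[0, 0, 0, 1, 1],
                     [1, 1, 0, 1, 1]], dtype=torch.bool)
source = torch.arange(10).reshape(2, 5)  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]

# The 6 True positions (row-major) receive source elements 0, 1, 2, 3, 4, 5
target.masked_scatter_(mask, source)
print(target)
# tensor([[0, 0, 0, 0, 1],
#         [2, 3, 0, 4, 5]])
```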

Parameters
mask (BoolTensor) – the boolean mask
source (Tensor) – the tensor to copy from

Example

import torch
mask = torch.BoolTensor([[1, 0, 0], [0, 1, 0],  [0, 0, 1]])
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
a = torch.randn(2,3,3)
s = torch.ones_like(a)
a.masked_scatter(mask, s)
# tensor([[[ 1.0000, -0.1560, -0.7760],
#          [-0.5192,  1.0000, -0.1709],
#          [ 0.2091,  0.5650,  1.0000]],

#         [[ 1.0000,  0.0623, -0.1447],
#          [-1.2910,  1.0000, -1.2722],
#          [-0.7864, -0.1118,  1.0000]]])
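masked_select and masked_scatter walk the masked positions in the same row-major order, so they act as inverses of each other. A minimal sketch, assuming the same diagonal mask broadcast over a (2, 3, 3) tensor as above: selecting the masked values and scattering them back into a zero tensor gives the same result as zeroing everything outside the mask.

```python
import torch

mask = torch.eye(3, dtype=torch.bool)  # diagonal mask, broadcast to (2, 3, 3)
a = torch.randn(2, 3, 3)

# 2 batches x 3 diagonal entries = 6 selected values, as a flat tensor
vals = torch.masked_select(a, mask)

# A flat source is fine: only its element count and order matter
restored = torch.zeros_like(a).masked_scatter(mask, vals)

# Same as keeping the masked entries and zeroing the rest
assert torch.equal(restored, a.masked_fill(~mask, 0))
```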