embedding = nn.Embedding(10, 3)声明一个Embedding层,最大的embeddings个数是10,维数为3。Embedding.weight会从标准正态分布中初始化成大小为(num_embeddings, embedding_dim)的矩阵,input中的标号表示从矩阵对应行获取权重来表示单词。所有的input变量都小于10,若大于10,则会报错。
# an Embedding module containing 10 tensors of size 3
embedding = nn.Embedding(10, 3)
# a batch of 2 samples of 4 indices each
input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
print(embedding(input))
tensor([[[-0.0251, -1.6902, 0.7172],
[-0.6431, 0.0748, 0.6969],
[ 1.4970, 1.3448, -0.9685],
[-0.3677, -2.7265, -0.1685]],
[[ 1.4970, 1.3448, -0.9685],
[ 0.4362, -0.4004, 0.9400],
[-0.6431, 0.0748, 0.6969],
[ 0.9124, -2.3616, 1.1151]]])

本文介绍PyTorch中Embedding层的基本用法,包括声明、初始化及输入输出等关键步骤。通过实例演示了如何使用Embedding层将索引转化为固定长度的向量。

1919

被折叠的 条评论
为什么被折叠?



