torch.mm
torch.mm是两个矩阵相乘,即两个二维的张量相乘
如下面的例子
mat1 = torch.randn(2,3)
print("mat1=", mat1)
mat2 = torch.randn(3,2)
print("mat2=", mat2)
mat3 = torch.mm(mat1, mat2)
print("mat3=", mat3)

但是如果维度超过二维,则会报错。RuntimeError: self must be a matrix
torch.bmm
它其实就是加了一维batch,所以第一位为batch,并且要两个Tensor的batch相等。
第二维和第三维就是mm运算了,同上了。
示例代码如下:
mat1 = torch.randn(10, 2, 4)
# print("mat1=", mat1)
mat2 = torch.randn(10, 4, 1)
# print("mat2=", mat2)
mat3 = torch.matmul(mat1, mat2)
print("mat3=", mat3, mat3.shape)

torch.matmul
torch.mm仅仅是供矩阵相乘使用,使用范围较为狭窄。
而torch.matmul使用的场合就比较多了。
如官方文档所介绍,有如下几种:
- If both tensors are 1-dimensional, the dot product (scalar) is returned.
如果两个tensor都是一维的,则为点乘运算,即每个元素对应相乘求和
如下:
mat1 = torch.Tensor([1,2])
print("mat1=", mat1)
mat2 = torch.Tensor([1<

本文详细解析了torch库中矩阵乘法函数torch.mm, torch.bmm和torch.matmul的用法区别,涵盖了不同维度的矩阵操作,从一维到多维的扩展,以及与numpy的对比。重点讲解了广播机制和batched matrix multiply的应用场景。

1万+

被折叠的 条评论
为什么被折叠?



