torch.mul、torch.mm、torch.bmm、torch.matmul的区别

本文详细解析了PyTorch中四种主要的矩阵运算函数:torch.mul、torch.mm、torch.bmm和torch.matmul的功能及使用场景。从对位相乘到不同维度的矩阵乘法,再到批量计算和广播机制,全面覆盖了PyTorch在矩阵运算方面的强大能力。
Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

本镜像基于 RTX 4090D 24GB 显存 + CUDA 12.4 + 驱动 550.90.07 深度优化,内置完整运行环境与 Qwen3-32B 模型依赖,开箱即用。

torch.mul

torch.mul(input, other, out=None)

功能

对位相乘,可以广播

该函数能处理两种情况

  1. input是矩阵/向量,other是标量
    这个时候是就是input的所有元素乘上other
  2. input是矩阵/向量,other是矩阵/向量
    这时 outi=inputi×otheriout_i = input_i \times other_iouti=inputi×otheri,对位相乘,如果两个都是向量,则可以广播的

例子

  1. input和other的size相同的对位相乘

    a: tensor([[ 1.8351,  2.1536],
        [-0.8320, -1.4578]])
    b: tensor([[2.9355, 0.3450],
        [0.5708, 1.9957]])
    c = torch.mul(a,b):
     tensor([[ 5.3869,  0.7429],
        [-0.4749, -2.9093]])
    
  2. 两个向量的广播

    a: tensor([[ 1.8351,  2.1536],
            [-0.8320, -1.4578]])
    b: tensor([[2.9355, 0.3450],
            [0.5708, 1.9957]])
    c = torch.mul(a,b):
     tensor([[ 5.3869,  0.7429],
            [-0.4749, -2.9093]])
    

torch.mm

torch.mm(input, mat2, out=None)

解决的问题

处理二维矩阵的乘法,而且也只能处理二维矩阵,其他维度要用torch.matmul

例子

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

torch.bmm

torch.bmm(input, mat2, out=None)

看函数名就知道,在torch.mm的基础上加了个batch计算,不能广播


torch.matmul

torch.matmul(input, other, out=None)

功能
适用性最多的,能处理batch、广播的矩阵:

  1. 如果第一个参数是一维,第二个是二维,那么给第一个提供一个维度
  2. 如果第一个是二维,第二个是一维,就是矩阵乘向量
  3. 带有batch的情况,可保留batch计算
  4. 维度不同时,可先广播,再batch计算

例子

  1. vector x vector

    tensor1 = torch.randn(3)
    tensor2 = torch.randn(3)
    torch.matmul(tensor1, tensor2).size()
    torch.Size([])
    
  2. matrix x vector

    tensor1 = torch.randn(3, 4)
    tensor2 = torch.randn(4)
    torch.matmul(tensor1, tensor2).size()
    torch.Size([3])
    
  3. batched matrix x broadcasted vecto

    tensor1 = torch.randn(10, 3, 4)
    tensor2 = torch.randn(4)
    torch.matmul(tensor1, tensor2).size()
    torch.Size([10, 3])
    
  4. batched matrix x batched matrix

    tensor1 = torch.randn(10, 3, 4)
    tensor2 = torch.randn(10, 4, 5)
    torch.matmul(tensor1, tensor2).size()
    torch.Size([10, 3, 5])
    

总结

对位相乘用torch.mul,二维矩阵乘法用torch.mm,batch二维矩阵用torch.bmm,batch、广播用torch.matmul

您可能感兴趣的与本文相关的镜像

Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

Qwen
文本生成
Qwen3

本镜像基于 RTX 4090D 24GB 显存 + CUDA 12.4 + 驱动 550.90.07 深度优化,内置完整运行环境与 Qwen3-32B 模型依赖,开箱即用。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值