CANN/ops-nn转置批处理矩阵乘法算子

TransposeBatchMatMul

【免费下载链接】ops-nn 本项目是CANN提供的神经网络类计算算子库,实现网络在NPU上加速计算。 【免费下载链接】ops-nn 项目地址: https://gitcode.com/cann/ops-nn

产品支持情况

产品是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2推理产品×
Atlas 推理系列产品×
Atlas 训练系列产品×
Kirin X90 处理器系列产品
Kirin 9030 处理器系列产品

功能说明

  • 算子功能:完成张量x1与张量x2的矩阵乘计算。仅支持三维的Tensor传入。Tensor支持转置,转置序列根据传入的序列进行变更。permX1代表张量x1的转置序列,支持[0,1,2]、[1,0,2],permX2代表张量x2的转置序列[0,1,2],permY表示矩阵乘输出矩阵的转置序列,当前仅支持[1,0,2],序列值为0的是batch维度,其余两个维度做矩阵乘法。scale表示输出矩阵的量化系数,可在输入为FLOAT16且输出为INT8时使能,详细约束条件可见约束说明或者aclnnTransposeBatchMatMul调用说明文档。

  • 示例:

    • x1的shape是(B, M, K),x2的shape是(B, K, N),scale为None,batchSplitFactor等于1时,计算输出out的shape是(M, B, N)。
    • x1的shape是(B, M, K),x2的shape是(B, K, N),scale不为None,batchSplitFactor等于1时,计算输出out的shape是(M, 1, B * N)。
    • x1的shape是(B, M, K),x2的shape是(B, K, N),scale为None,batchSplitFactor大于1时,计算输出out的shape是(batchSplitFactor, M, B * N / batchSplitFactor)。

参数说明

参数名输入/输出/属性描述数据类型数据格式
x1输入矩阵乘运算中的左矩阵。FLOAT32, FLOAT16, BF16ND
x2输入矩阵乘运算中的右矩阵。FLOAT32, FLOAT16, BF16ND
bias输入矩阵乘运算后累加的偏置。FLOAT32, FLOAT16, BF16ND
scale输入量化参数的缩放因子。INT64, UINT64ND
permX1输入表示矩阵乘的第一个矩阵的转置序列。INT64-
permX2输入表示矩阵乘的第二个矩阵的转置序列。INT64-
permY输入表示矩阵乘输出矩阵的转置序列。INT64-
cubeMathType输入指定Cube单元的计算逻辑。INT8-
batchSplitFactor输入用于指定矩阵乘输出矩阵中B维的切分大小。INT32-
y输出矩阵乘运算的计算结果。FLOAT32, FLOAT16, BF16, INT8ND
  • Kirin X90/Kirin 9030处理器系列产品:不支持BFLOAT16。

约束说明

  • Atlas A2 训练系列产品/Atlas A2 推理系列产品 、 Atlas A3 训练系列产品/Atlas A3 推理系列产品 :
    • 不支持空tensor。
    • 支持非连续tensor。
    • B的取值范围为[1, 65536),N的取值范围为[1, 65536)。
    • 当x1的输入shape为(B, M, K)时,K <= 65535;当x1的输入shape为(M, B, K)时,B * K <= 65535。
    • 当scale不为空时,batchSplitFactor只能等于1,B与N的乘积小于65536, 且仅支持输入为FLOAT16和输出为INT8的类型推导。
  • Ascend 950PR/Ascend 950DT :
    • 当scale不为空时,batchSplitFactor只能等于1,且仅支持输入为FLOAT16和输出为INT8的类型推导。
    • bias为预留参数,当前暂不支持。

调用说明

调用方式样例代码说明
aclnn接口test_aclnn_transpose_batch_mat_mul通过
aclnnTransposeBatchMatMul
等方式调用TransposeBatchMatMul算子。

【免费下载链接】ops-nn 本项目是CANN提供的神经网络类计算算子库,实现网络在NPU上加速计算。 【免费下载链接】ops-nn 项目地址: https://gitcode.com/cann/ops-nn

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值