Usage of the matmul/mm Functions

This article introduces PyTorch's torch.mm and torch.matmul functions, both used for matrix multiplication. torch.mm works only on 2-D matrices, while torch.matmul supports 2-D as well as higher-dimensional tensors. In higher-dimensional multiplication, the dimensions beyond the last two are matched or expanded according to fixed rules. For example, when multiplying two 3-D tensors, a shared leading dimension is factored out directly; if one tensor's leading dimension is 1, it is expanded to match the other; otherwise an error is raised.


Before introducing torch.matmul, let's first look at torch.mm. Both mm and matmul are matrix multiplication functions in torch: mm only works on 2-D matrices, while matmul works on 2-D as well as higher-dimensional tensors.

Using the mm function

import torch

x = torch.rand(4, 9)
y = torch.rand(9, 8)
print(torch.mm(x, y).shape)

torch.Size([4, 8])
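To make the 2-D restriction concrete, here is a minimal sketch showing torch.mm rejecting a 3-D input (the exact error message may vary across PyTorch versions):

import torch

x3 = torch.rand(9, 4, 9)   # 3-D tensor
y = torch.rand(9, 8)
try:
    torch.mm(x3, y)        # mm requires both arguments to be 2-D matrices
except RuntimeError as e:
    print(e)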

Using the matmul function

  • 1 2-D × 2-D: the result is the same as with the mm function
x = torch.rand(4, 9)
y = torch.rand(9, 8)

print(torch.matmul(x, y).shape)
torch.Size([4, 8])
  • 2 Higher-dimensional matmul: 3-D × 2-D
    The 0th dimension of x is factored out as a batch dimension, and what remains is a 2-D matrix multiplication, giving 9, 4, 8 (a verification sketch follows the example).
x = torch.rand(9, 4, 9)
y = torch.rand(9, 8)

print(torch.matmul(x, y).shape)

torch.Size([9, 4, 8])
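As a sanity check, this batching over the 0th dimension is equivalent to running torch.mm slice by slice and stacking the results; a minimal sketch, verified with torch.allclose to allow for float rounding:

import torch

x = torch.rand(9, 4, 9)
y = torch.rand(9, 8)

out = torch.matmul(x, y)
# one 2-D mm per slice along dim 0, stacked back together
ref = torch.stack([torch.mm(x[i], y) for i in range(x.shape[0])])
print(torch.allclose(out, ref))  # True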
  • 3 Higher-dimensional matmul: 3-D × 3-D
    Two cases:
    1) x and y have the same 0th dimension: it is factored out directly, and the remaining 2-D matrix multiplication gives 9, 5, 6.
    2) x and y differ in the 0th dimension: if one of them is 1, it is expanded to match the other and then factored out, leaving the same 2-D multiplication as in the first case; otherwise an error is raised. Below, x's 0th dimension is 1 and y's is 9, so x is expanded to 9, reducing to the first case (a sketch checking this against an explicit expand follows the second example).
x = torch.rand(9, 5, 8)
y = torch.rand(9, 8, 6)
print(torch.matmul(x, y).shape)

torch.Size([9, 5, 6])

x = torch.rand(1, 5, 8)
y = torch.rand(9, 8, 6)
print(torch.matmul(x, y).shape)

torch.Size([9, 5, 6])
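The broadcast in the second case can be checked against an explicit expand: letting matmul broadcast x's 0th dimension from 1 to 9 should match expanding it by hand. A minimal sketch:

import torch

x = torch.rand(1, 5, 8)
y = torch.rand(9, 8, 6)

out = torch.matmul(x, y)                  # dim 0 of x broadcast from 1 to 9
ref = torch.matmul(x.expand(9, 5, 8), y)  # the same expansion done by hand
print(out.shape)                          # torch.Size([9, 5, 6])
print(torch.allclose(out, ref))           # True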
  • 4 Higher-dimensional matmul: 4-D × 3-D (and beyond)
    The same rule applies: the extra leading dimension(s) are factored out as batch dimensions, and the remaining 2-D matrices are multiplied. Leading dimensions are matched from the right: equal sizes pass through unchanged, a size of 1 is expanded to match the corresponding size in the other tensor, and any other mismatch raises an error (an error sketch follows the examples).
x = torch.rand(6, 9, 5, 8)
y = torch.rand(9, 8, 6)
print(torch.matmul(x, y).shape)
torch.Size([6, 9, 5, 6])

x = torch.rand(6, 1, 5, 8)
y = torch.rand(9, 8, 6)
print(torch.matmul(x, y).shape)
torch.Size([6, 9, 5, 6])

x = torch.rand(6, 9, 1, 5, 8)
y = torch.rand(9, 8, 6)
print(torch.matmul(x, y).shape)
torch.Size([6, 9, 9, 5, 6])

x = torch.rand(6, 9, 8, 9, 1, 5, 8)
y = torch.rand(9, 8, 6)
print(torch.matmul(x, y).shape)
torch.Size([6, 9, 8, 9, 9, 5, 6])
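And the error case: when the batch dimensions differ and neither is 1, broadcasting fails. A minimal sketch (the exact message may vary by PyTorch version):

import torch

x = torch.rand(5, 5, 8)    # batch dimension 5
y = torch.rand(9, 8, 6)    # batch dimension 9: neither is 1, so no broadcast
try:
    torch.matmul(x, y)
except RuntimeError as e:
    print(e)               # size mismatch at a non-singleton dimension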
