// 矩阵乘法 (Matrix multiplication)

#include <stdio.h>

#include <chrono>
#include <cstdlib>
#include <iostream>
#include <random>

using namespace std;
void strassen_sub_multi_matrix(struct SubMatrix&, struct SubMatrix&, struct Matrix&);
// Integer matrix stored as an array of `row` row pointers, each pointing to
// `column` ints (element (i, j) is matrix[i][j]).
// Storage is managed manually: it is allocated by create_void_matrix /
// create_random_matrix and released by delete_matrix.
struct Matrix {
    int row = 0;             // number of rows
    int column = 0;          // number of columns
    int** matrix = nullptr;  // row pointers; nullptr until allocated
};
// A rectangular, non-owning view into a matrix's row-pointer array:
// rows [row_start_pos, row_end_pos] and columns [column_start_pos, column_end_pos],
// both ranges inclusive. A range with start > end is empty.
struct SubMatrix {
    int row_start_pos = 0;
    int row_end_pos = 0;
    int column_start_pos = 0;
    int column_end_pos = 0;
    int** matrix = nullptr;  // borrowed row pointers of the viewed matrix
};
// Create a row x column matrix whose elements are all zero.
// The storage is allocated with new[]; release it with delete_matrix.
void create_void_matrix(struct Matrix& m, const int& row, const int& column) {
    m.row = row;
    m.column = column;
    m.matrix = new int*[row];
    for (int i = 0; i < row; ++i) {
        // The trailing () value-initializes the row, i.e. fills it with 0.
        // Plain `new int[column]` would leave the elements indeterminate.
        m.matrix[i] = new int[column]();
    }
}
//创建随机矩阵
void create_random_matrix(struct Matrix& m, const int& row, const int& column) {
    m.row = row;
    m.column = column;
    m.matrix = new int*[row];
    for (int i = 0; i < m.row; ++i) {
        m.matrix[i] = new int[column];
    }
    auto seed = chrono::high_resolution_clock::now().time_since_epoch().count();
    static default_random_engine e(seed);
    static uniform_int_distribution<int> u(1, 100);
    for (int i = 0; i < m.row; ++i) {
        for (int j = 0; j < m.column; ++j) {
            m.matrix[i][j] = u(e);
        }
    }
}
// Print m to stdout, one line per row. Each element is followed by two
// spaces, and a blank line is printed after the last row.
void print_matrix(const struct Matrix& m) {
    for (int r = 0; r < m.row; ++r) {
        const int* const cells = m.matrix[r];
        for (int c = 0; c < m.column; ++c) {
            printf("%d  ", cells[c]);
        }
        printf("\n");  // end of this row
    }
    printf("\n");  // blank separator line after the matrix
}
// Release the storage allocated by create_void_matrix / create_random_matrix.
// Afterwards m is reset to an empty 0 x 0 matrix with a null row array, so a
// repeated delete_matrix call is a harmless no-op instead of a double free.
void delete_matrix(struct Matrix& m) {
    if (m.matrix != nullptr) {
        for (int i = 0; i < m.row; ++i) {
            delete[] m.matrix[i];
        }
        delete[] m.matrix;
    }
    m.matrix = nullptr;
    m.row = 0;
    m.column = 0;
}
// Standard matrix multiplication A(m,n) * B(n,p) for any positive m, n, p.
// Computes the m x p product, prints it to stdout, then frees it.
// If the inner dimensions differ (m1.column != m2.row), an error is written
// to stderr and nothing is computed.
void normal_multi_matrix(struct Matrix& m1, struct Matrix& m2) {
    if (m1.column != m2.row) {
        fprintf(stderr, "normal_multi_matrix: cannot multiply %dx%d by %dx%d\n",
                m1.row, m1.column, m2.row, m2.column);
        return;
    }
    struct Matrix m;
    create_void_matrix(m, m1.row, m2.column);
    for (int i = 0; i < m1.row; ++i) {
        // The product has m2.column columns (not m1.row, which only
        // coincides with it for square inputs).
        for (int j = 0; j < m2.column; ++j) {
            int sum = 0;
            for (int k = 0; k < m1.column; ++k) {
                sum += m1.matrix[i][k] * m2.matrix[k][j];
            }
            m.matrix[i][j] = sum;
        }
    }
    print_matrix(m);
    delete_matrix(m);
}
// Divide-and-conquer matrix multiplication: adds A * B into m, where A is the
// block viewed by subm1 and B is the block viewed by subm2 (inclusive ranges).
// Each level splits every range in half and recurses on the 8 block products
// C_ij += A_ik * B_kj. Leaves are single elements, and each leaf product is
// traced to stdout.
// Requirements:
//   - subm1's column range and subm2's row range have the same length;
//   - m covers subm1's row indices and subm2's column indices, and has been
//     zeroed by the caller (results are accumulated with +=).
// Any size works, not only powers of two: splitting a one-element range
// yields that element plus an empty half, and empty ranges contribute nothing.
// subm1/subm2 are used as scratch views during the recursion and are restored
// to their original ranges before returning.
void normal_sub_multi_matrix(struct SubMatrix& subm1, struct SubMatrix& subm2, struct Matrix& m) {
    // An empty range (start > end) contributes nothing to the product.
    if (subm1.row_start_pos > subm1.row_end_pos || subm1.column_start_pos > subm1.column_end_pos ||
        subm2.row_start_pos > subm2.row_end_pos || subm2.column_start_pos > subm2.column_end_pos) {
        return;
    }
    // 1x1 * 1x1: accumulate the scalar product.
    if ((subm1.row_end_pos == subm1.row_start_pos) && (subm1.column_start_pos == subm1.column_end_pos) && (subm2.row_start_pos == subm2.row_end_pos) && (subm2.column_start_pos == subm2.column_end_pos)) {
        printf("a[%d][%d] = %d * b[%d][%d] = %d\n", subm1.row_start_pos,subm1.column_start_pos,subm1.matrix[subm1.row_start_pos][subm1.column_start_pos],subm2.row_start_pos,subm2.column_start_pos,subm2.matrix[subm2.row_start_pos][subm2.column_start_pos]);
        m.matrix[subm1.row_start_pos][subm2.column_start_pos] += subm1.matrix[subm1.row_start_pos][subm1.column_start_pos] * subm2.matrix[subm2.row_start_pos][subm2.column_start_pos];
        return;
    }
    // Snapshot the caller's views; subm1/subm2 are reused as the child views.
    const struct SubMatrix a = subm1;
    const struct SubMatrix b = subm2;
    // Last index of the first half of each range; the second half starts at mid + 1.
    const int row_mid1 = a.row_start_pos + (a.row_end_pos - a.row_start_pos) / 2;
    const int col_mid1 = a.column_start_pos + (a.column_end_pos - a.column_start_pos) / 2;
    const int row_mid2 = b.row_start_pos + (b.row_end_pos - b.row_start_pos) / 2;
    const int col_mid2 = b.column_start_pos + (b.column_end_pos - b.column_start_pos) / 2;
    // C_ij += A_ik * B_kj, visited in the order C11, C12, C21, C22 with k = 1, 2 for each.
    for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
            for (int k = 0; k < 2; ++k) {
                // A_ik: row half i, column half k.
                subm1.row_start_pos = (i == 0) ? a.row_start_pos : row_mid1 + 1;
                subm1.row_end_pos = (i == 0) ? row_mid1 : a.row_end_pos;
                subm1.column_start_pos = (k == 0) ? a.column_start_pos : col_mid1 + 1;
                subm1.column_end_pos = (k == 0) ? col_mid1 : a.column_end_pos;
                // B_kj: row half k, column half j.
                subm2.row_start_pos = (k == 0) ? b.row_start_pos : row_mid2 + 1;
                subm2.row_end_pos = (k == 0) ? row_mid2 : b.row_end_pos;
                subm2.column_start_pos = (j == 0) ? b.column_start_pos : col_mid2 + 1;
                subm2.column_end_pos = (j == 0) ? col_mid2 : b.column_end_pos;
                normal_sub_multi_matrix(subm1, subm2, m);
            }
        }
    }
    // Give the caller back its original views.
    subm1 = a;
    subm2 = b;
}
// Multiply m1 (r x n) by m2 (n x p) and print the r x p result.
// When both inputs are square, have the same size, and that size is a power
// of two >= 2, the product is computed by strassen_sub_multi_matrix, working
// on SubMatrix views of the inputs. Any other compatible shape is computed
// with a plain triple loop.
// If m1.column != m2.row, an error is written to stderr and nothing is computed.
void normal_sub_multi_matrix(const struct Matrix& m1, const struct Matrix& m2) {
    if (m1.column != m2.row) {
        fprintf(stderr, "normal_sub_multi_matrix: cannot multiply %dx%d by %dx%d\n",
                m1.row, m1.column, m2.row, m2.column);
        return;
    }
    struct Matrix m;
    create_void_matrix(m, m1.row, m2.column);
    // The routines below accumulate with +=, so start from an all-zero result.
    for (int i = 0; i < m.row; ++i) {
        for (int j = 0; j < m.column; ++j) {
            m.matrix[i][j] = 0;
        }
    }
    const int n = m1.row;
    // Given m1.column == m2.row, this means both inputs are n x n with n a
    // power of two >= 2.
    const bool strassen_ok = n >= 2 && (n & (n - 1)) == 0 && m1.column == n && m2.column == n;
    if (strassen_ok) {
        struct SubMatrix subm1;
        struct SubMatrix subm2;
        subm1.row_start_pos = 0;
        subm1.column_start_pos = 0;
        subm2.row_start_pos = 0;
        subm2.column_start_pos = 0;
        subm1.row_end_pos = m1.row - 1;
        subm1.column_end_pos = m1.column - 1;
        subm2.row_end_pos = m2.row - 1;
        subm2.column_end_pos = m2.column - 1;
        subm1.matrix = m1.matrix;
        subm2.matrix = m2.matrix;
        //normal_sub_multi_matrix(subm1, subm2, m);
        strassen_sub_multi_matrix(subm1, subm2, m);
    } else {
        for (int i = 0; i < m1.row; ++i) {
            for (int j = 0; j < m2.column; ++j) {
                for (int k = 0; k < m1.column; ++k) {
                    m.matrix[i][j] += m1.matrix[i][k] * m2.matrix[k][j];
                }
            }
        }
    }
    print_matrix(m);
    delete_matrix(m);
}
//通过施拉特森矩阵乘法变型的(避免了矩阵加减法复制矩阵的开销,递归方法自下而上)
void strassen_sub_multi_matrix(struct SubMatrix& subm1, struct SubMatrix& subm2, struct Matrix& m) {
    if (subm1.row_end_pos - subm1.row_start_pos == 1 && subm1.column_end_pos - subm1.column_start_pos == 1 && subm2.row_end_pos - subm2.row_start_pos == 1 && subm2.column_end_pos - subm2.column_start_pos == 1) {
        int A11 = subm1.matrix[subm1.row_start_pos][subm1.column_start_pos];
        int A12 = subm1.matrix[subm1.row_start_pos][subm1.column_end_pos];
        int A21 = subm1.matrix[subm1.row_end_pos][subm1.column_start_pos];
        int A22 = subm1.matrix[subm1.row_end_pos][subm1.column_end_pos];
        int B11 = subm2.matrix[subm2.row_start_pos][subm2.column_start_pos];
        int B12 = subm2.matrix[subm2.row_start_pos][subm2.column_end_pos];
        int B21 = subm2.matrix[subm2.row_end_pos][subm2.column_start_pos];
        int B22 = subm2.matrix[subm2.row_end_pos][subm2.column_end_pos];
        int M1 = (A11 + A22) * (B11 + B22);
        int M2 = (A21 + A22) * B11;
        int M3 = A11 * (B12 - B22);
        int M4 = A22 * (B21 - B11);
        int M5 = (A11 + A12) * B22;
        int M6 = (A21 - A11) * (B11 + B12);
        int M7 = (A12 - A22) * (B21 + B22);
        //for C11
        m.matrix[subm1.row_start_pos][subm2.column_start_pos] += M1 + M4 - M5 + M7;
        //for C12
        m.matrix[subm1.row_start_pos][subm2.column_end_pos] += M3 + M5;
        //for C21
        m.matrix[subm1.row_end_pos][subm2.column_start_pos] += M2 + M4;
        //for C22
        m.matrix[subm1.row_end_pos][subm2.column_end_pos] += M1 - M2 + M3 + M6;
        return;
    }
    int rsp1 = subm1.row_start_pos;
    int rep1 = subm1.row_end_pos;
    int csp1 = subm1.column_start_pos;
    int cep1 = subm1.column_end_pos;
    int rsp2 = subm2.row_start_pos;
    int rep2 = subm2.row_end_pos;
    int csp2 = subm2.column_start_pos;
    int cep2 = subm2.column_end_pos;
    //C11 = A11 * B11 + A12 * B21
    //get A11 * B11
    //get A11
    subm1.row_start_pos = rsp1;
    subm1.row_end_pos = rsp1 + (rep1 - rsp1) / 2;
    subm1.column_start_pos = csp1;
    subm1.column_end_pos = csp1 + (cep1 - csp1) / 2;
    //get B11
    subm2.row_start_pos = rsp2;
    subm2.row_end_pos = rsp2 + (rep2 - rsp2) / 2;
    subm2.column_start_pos = csp2;
    subm2.column_end_pos = csp2 + (cep2 - csp2) / 2;
    strassen_sub_multi_matrix(subm1, subm2, m);
    //get A12 * B21
    //get A12
    subm1.row_start_pos = rsp1;
    subm1.row_end_pos = rsp1 + (rep1 - rsp1) / 2;
    subm1.column_start_pos = csp1 + (cep1 - csp1) / 2 + 1;
    subm1.column_end_pos = cep1;
    //get B21
    subm2.row_start_pos = rsp2 + (rep2 - rsp2) / 2 + 1;
    subm2.row_end_pos = rep2;
    subm2.column_start_pos = csp2;
    subm2.column_end_pos = csp2 + (cep2 - csp2) / 2;
    strassen_sub_multi_matrix(subm1, subm2, m);
    //C12 = A11 * B12 + A12 * B22
    //get A11 * B12
    //get A11
    subm1.row_start_pos = rsp1;
    subm1.row_end_pos = rsp1 + (rep1 - rsp1) / 2;
    subm1.column_start_pos = csp1;
    subm1.column_end_pos = csp1 + (cep1 - csp1) / 2;
    //get B12
    subm2.row_start_pos = rsp2;
    subm2.row_end_pos = rsp2 + (rep2 - rsp2) / 2;
    subm2.column_start_pos = csp2 + (cep2 - csp2) / 2 + 1;
    subm2.column_end_pos = cep2;
    strassen_sub_multi_matrix(subm1, subm2, m);
    //get A12 * B22
    //get A12
    subm1.row_start_pos = rsp1;
    subm1.row_end_pos = rsp1 + (rep1 - rsp1) / 2;
    subm1.column_start_pos = csp1 + (cep1 - csp1) / 2 + 1;
    subm1.column_end_pos = cep1;
    //get B22
    subm2.row_start_pos = rsp2 + (rep2 - rsp2) / 2 + 1;
    subm2.row_end_pos = rep2;
    subm2.column_start_pos = csp2 + (cep2 - csp2) / 2 + 1;
    subm2.column_end_pos = cep2;
    strassen_sub_multi_matrix(subm1, subm2, m);
    //for C21 = A21 * B11 + A22 * B21
    //get A21 * B11
    //get A21
    subm1.row_start_pos = rsp1 + (rep1 - rsp1) / 2 + 1;
    subm1.row_end_pos = rep1;
    subm1.column_start_pos = csp1;
    subm1.column_end_pos = csp1 + (cep1 - csp1) / 2;
    //get B11
    subm2.row_start_pos = rsp2;
    subm2.row_end_pos = rsp2 + (rep2 - rsp2) / 2;
    subm2.column_start_pos = csp2;
    subm2.column_end_pos = csp2 + (cep2 - csp2) / 2;
    strassen_sub_multi_matrix(subm1, subm2, m);
    //get A22
    subm1.row_start_pos = rsp1 + (rep1 - rsp1) / 2 + 1;
    subm1.row_end_pos = rep1;
    subm1.column_start_pos = csp1 + (cep1 - csp1) / 2 + 1;
    subm1.column_end_pos = cep1;
    //get B21
    subm2.row_start_pos = rsp2 + (rep2 - rsp2) / 2 + 1;
    subm2.row_end_pos = rep2;
    subm2.column_start_pos = csp2;
    subm2.column_end_pos = csp2 + (cep2 - csp2) / 2;
    strassen_sub_multi_matrix(subm1, subm2, m);
    //for C22 = A21 * B12 + A22 * B22
    //get A21 * B12
    //get A21
    subm1.row_start_pos = rsp1 + (rep1 - rsp1) / 2 + 1;
    subm1.row_end_pos = rep1;
    subm1.column_start_pos = csp1;
    subm1.column_end_pos = csp1 + (cep1 - csp1) / 2;
    //get B12
    subm2.row_start_pos = rsp2;
    subm2.row_end_pos = rsp2 + (rep2 - rsp2) / 2;
    subm2.column_start_pos = csp2 + (cep2 - csp2) / 2 + 1;
    subm2.column_end_pos = cep2;
    strassen_sub_multi_matrix(subm1, subm2, m);
    //get A22 * B22
    //get A22
    subm1.row_start_pos = rsp1 + (rep1 - rsp1) / 2 + 1;
    subm1.row_end_pos = rep1;
    subm1.column_start_pos = csp1 + (cep1 - csp1) / 2 + 1;
    subm1.column_end_pos = cep1;
    //get B22
    subm2.row_start_pos = rsp2 + (rep2 - rsp2) / 2 + 1;
    subm2.row_end_pos = rep2;
    subm2.column_start_pos = csp2 + (cep2 - csp2) / 2 + 1;
    subm2.column_end_pos = cep2;
    strassen_sub_multi_matrix(subm1, subm2, m);
}
// Demo: build two random 2x2 matrices, print them, then print their product
// computed by normal_multi_matrix and by the sub-matrix routine.
int main(int argc,char** argv) {
    (void)argc;  // command-line arguments are not used
    (void)argv;
    struct Matrix m1;
    struct Matrix m2;
    create_random_matrix(m1,2,2);
    create_random_matrix(m2,2,2);
    print_matrix(m1);
    print_matrix(m2);
    normal_multi_matrix(m1, m2);
    normal_sub_multi_matrix(m1, m2);
    delete_matrix(m1);
    delete_matrix(m2);
#ifdef _WIN32
    // "pause" is a cmd.exe built-in that keeps the console window open. Other
    // platforms have no such command; the shell would only print an error.
    std::system("pause");
#endif
    return 0;
}

这里写图片描述

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值