#include <stdio.h>

#include <chrono>
#include <cstdlib>
#include <iostream>
#include <random>

using namespace std;
// Forward declaration: normal_sub_multi_matrix(const Matrix&, const Matrix&) calls
// strassen_sub_multi_matrix before its definition appears. The elaborated "struct"
// specifiers are needed because Matrix and SubMatrix are only defined further down.
void strassen_sub_multi_matrix(struct SubMatrix& subm1, struct SubMatrix& subm2, struct Matrix& m);
// A row x column int matrix stored as an array of row pointers: matrix[i][j] is
// the element in row i, column j. Storage is allocated by create_void_matrix() /
// create_random_matrix() and released by delete_matrix().
// Members default to an empty 0x0 matrix so a Matrix that was never filled can
// still be inspected or passed to delete_matrix() safely.
struct Matrix {
    int row = 0;             // number of rows
    int column = 0;          // number of columns
    int** matrix = nullptr;  // `row` row pointers, each to `column` ints; nullptr when unallocated
};
// A rectangular window onto a matrix: rows row_start_pos..row_end_pos and columns
// column_start_pos..column_end_pos of `matrix`, both bounds inclusive (a full
// n-row matrix is described by row_start_pos = 0, row_end_pos = n - 1).
// The window only borrows `matrix`; it never allocates or frees it.
struct SubMatrix {
    int row_start_pos = 0;
    int row_end_pos = 0;
    int column_start_pos = 0;
    int column_end_pos = 0;
    int** matrix = nullptr;  // row pointers of the viewed matrix (not owned)
};
// Create a zero matrix: allocates a row x column matrix into m with every element 0.
// The previous contents of m are overwritten without being freed; the caller owns
// the new storage and must release it with delete_matrix().
void create_void_matrix(struct Matrix& m, const int& row, const int& column) {
    m.row = row;
    m.column = column;
    m.matrix = new int*[row];
    for (int i = 0; i < row; i++) {
        // The trailing "()" value-initializes the row, i.e. fills it with zeros;
        // plain `new int[column]` would leave every element indeterminate.
        m.matrix[i] = new int[column]();
    }
}
// Create a random matrix: allocates a row x column matrix into m and fills it with
// uniformly distributed integers in [1, 100]. The caller owns the storage and must
// release it with delete_matrix().
void create_random_matrix(struct Matrix& m, const int& row, const int& column) {
    m.row = row;
    m.column = column;
    m.matrix = new int*[row];
    for (int i = 0; i < row; i++) {
        m.matrix[i] = new int[column];
    }
    // The engine is a function-local static, so it is seeded from the clock exactly
    // once (on the first call) and later calls continue the same sequence. Reading
    // the clock inside the static initializer avoids re-reading it on every call for
    // a seed value that would be discarded anyway.
    static std::default_random_engine engine(
        static_cast<std::default_random_engine::result_type>(
            std::chrono::high_resolution_clock::now().time_since_epoch().count()));
    static std::uniform_int_distribution<int> dist(1, 100);
    for (int i = 0; i < row; i++) {
        for (int j = 0; j < column; j++) {
            m.matrix[i][j] = dist(engine);
        }
    }
}
// Print the matrix to stdout: every element is followed by one space, each row
// ends with a newline, and one extra empty line is written after the last row.
void print_matrix(const struct Matrix& m) {
    for (int r = 0; r < m.row; ++r) {
        const int* const line = m.matrix[r];
        int c = 0;
        while (c < m.column) {
            printf("%d ", line[c]);
            ++c;
        }
        printf("\n");
    }
    printf("\n");
}
// Free a matrix: releases the rows and the row-pointer array allocated by
// create_void_matrix()/create_random_matrix(), then resets m to an empty 0x0
// matrix. Because the pointer is cleared, calling this again on the same Matrix
// is a harmless no-op instead of a double free.
void delete_matrix(struct Matrix& m) {
    if (m.matrix != nullptr) {
        for (int i = 0; i < m.row; i++) {
            delete[] m.matrix[i];
        }
        delete[] m.matrix;
    }
    m.matrix = nullptr;
    m.row = 0;
    m.column = 0;
}
//标准矩阵乘法A(m,n) * B(n,p) (m,n,p任意正整数)
void normal_multi_matrix(struct Matrix& m1,struct Matrix& m2) {
struct Matrix m
create_void_matrix(m, m1.row, m2.column)
for (int i = 0
for (int j = 0
m.matrix[i][j] = 0
for (int k = 0
m.matrix[i][j] += m1.matrix[i][k] * m2.matrix[k][j]
}
}
}
print_matrix(m)
delete_matrix(m)
}
// Standard divide-and-conquer matrix multiplication on index windows: accumulates
// subm1 * subm2 into m, writing row i of subm1's window and column j of subm2's
// window to m.matrix[i][j]. Nothing is copied; each level halves every window and
// performs the eight recursive products
//   C11 = A11*B11 + A12*B21,  C12 = A11*B12 + A12*B22,
//   C21 = A21*B11 + A22*B21,  C22 = A21*B12 + A22*B22.
// Every scalar product is traced to stdout.
//
// Preconditions: m already covers the result positions and is zeroed there (the
// result is accumulated with +=); subm1's column window and subm2's row window
// have the same length. Any sizes are handled: halving an odd-length window
// yields an empty upper half, which is skipped. (The original only terminated
// correctly for power-of-two sizes despite claiming "n even".)
// The caller's subm1/subm2 windows are restored before returning.
void normal_sub_multi_matrix(struct SubMatrix& subm1, struct SubMatrix& subm2, struct Matrix& m) {
    // An empty window contributes nothing.
    if (subm1.row_start_pos > subm1.row_end_pos || subm1.column_start_pos > subm1.column_end_pos ||
        subm2.row_start_pos > subm2.row_end_pos || subm2.column_start_pos > subm2.column_end_pos) {
        return;
    }
    // Both windows are single elements: one scalar product.
    if (subm1.row_start_pos == subm1.row_end_pos && subm1.column_start_pos == subm1.column_end_pos &&
        subm2.row_start_pos == subm2.row_end_pos && subm2.column_start_pos == subm2.column_end_pos) {
        printf("a[%d][%d] = %d * b[%d][%d] = %d\n", subm1.row_start_pos, subm1.column_start_pos,
               subm1.matrix[subm1.row_start_pos][subm1.column_start_pos], subm2.row_start_pos,
               subm2.column_start_pos, subm2.matrix[subm2.row_start_pos][subm2.column_start_pos]);
        m.matrix[subm1.row_start_pos][subm2.column_start_pos] +=
            subm1.matrix[subm1.row_start_pos][subm1.column_start_pos] *
            subm2.matrix[subm2.row_start_pos][subm2.column_start_pos];
        return;
    }

    // Remember the caller's windows; subm1/subm2 are reused as scratch below.
    const struct SubMatrix a = subm1;
    const struct SubMatrix b = subm2;

    const int a_row_mid = a.row_start_pos + (a.row_end_pos - a.row_start_pos) / 2;
    const int a_col_mid = a.column_start_pos + (a.column_end_pos - a.column_start_pos) / 2;
    const int b_row_mid = b.row_start_pos + (b.row_end_pos - b.row_start_pos) / 2;
    const int b_col_mid = b.column_start_pos + (b.column_end_pos - b.column_start_pos) / 2;

    // Inclusive {start, end} of the first (index 0) and second (index 1) half of each window.
    const int a_rows[2][2] = {{a.row_start_pos, a_row_mid}, {a_row_mid + 1, a.row_end_pos}};
    const int a_cols[2][2] = {{a.column_start_pos, a_col_mid}, {a_col_mid + 1, a.column_end_pos}};
    const int b_rows[2][2] = {{b.row_start_pos, b_row_mid}, {b_row_mid + 1, b.row_end_pos}};
    const int b_cols[2][2] = {{b.column_start_pos, b_col_mid}, {b_col_mid + 1, b.column_end_pos}};

    // C(i,j) += A(i,k) * B(k,j); iterating i, then j, then k reproduces the
    // order C11 (A11*B11, A12*B21), C12, C21, C22 of the formulas above.
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            for (int k = 0; k < 2; k++) {
                subm1.row_start_pos = a_rows[i][0];
                subm1.row_end_pos = a_rows[i][1];
                subm1.column_start_pos = a_cols[k][0];
                subm1.column_end_pos = a_cols[k][1];
                subm2.row_start_pos = b_rows[k][0];
                subm2.row_end_pos = b_rows[k][1];
                subm2.column_start_pos = b_cols[j][0];
                subm2.column_end_pos = b_cols[j][1];
                normal_sub_multi_matrix(subm1, subm2, m);
            }
        }
    }

    subm1 = a;
    subm2 = b;
}
void normal_sub_multi_matrix(const struct Matrix& m1, const struct Matrix& m2) {
struct SubMatrix subm1
struct SubMatrix subm2
subm1.row_start_pos = 0
subm1.column_start_pos = 0
subm2.row_start_pos = 0
subm2.column_start_pos = 0
subm1.row_end_pos = m1.row - 1
subm1.column_end_pos = m1.column - 1
subm2.row_end_pos = m2.row - 1
subm2.column_end_pos = m2.column - 1
subm1.matrix = m1.matrix
subm2.matrix = m2.matrix
struct Matrix m
create_void_matrix(m,m1.row,m2.column)
for (int i = 0
for (int j = 0
m.matrix[i][j] = 0
}
}
//normal_sub_multi_matrix(subm1, subm2, m)
strassen_sub_multi_matrix(subm1, subm2, m)
print_matrix(m)
delete_matrix(m)
}
//通过施拉特森矩阵乘法变型的(避免了矩阵加减法复制矩阵的开销,递归方法自下而上)
void strassen_sub_multi_matrix(struct SubMatrix& subm1, struct SubMatrix& subm2, struct Matrix& m) {
if (subm1.row_end_pos - subm1.row_start_pos == 1 && subm1.column_end_pos - subm1.column_start_pos == 1 && subm2.row_end_pos - subm2.row_start_pos == 1 && subm2.column_end_pos - subm2.column_start_pos == 1) {
int A11 = subm1.matrix[subm1.row_start_pos][subm1.column_start_pos]
int A12 = subm1.matrix[subm1.row_start_pos][subm1.column_end_pos]
int A21 = subm1.matrix[subm1.row_end_pos][subm1.column_start_pos]
int A22 = subm1.matrix[subm1.row_end_pos][subm1.column_end_pos]
int B11 = subm2.matrix[subm2.row_start_pos][subm2.column_start_pos]
int B12 = subm2.matrix[subm2.row_start_pos][subm2.column_end_pos]
int B21 = subm2.matrix[subm2.row_end_pos][subm2.column_start_pos]
int B22 = subm2.matrix[subm2.row_end_pos][subm2.column_end_pos]
int M1 = (A11 + A22) * (B11 + B22)
int M2 = (A21 + A22) * B11
int M3 = A11 * (B12 - B22)
int M4 = A22 * (B21 - B11)
int M5 = (A11 + A12) * B22
int M6 = (A21 - A11) * (B11 + B12)
int M7 = (A12 - A22) * (B21 + B22)
//for C11
m.matrix[subm1.row_start_pos][subm2.column_start_pos] += M1 + M4 - M5 + M7
//for C12
m.matrix[subm1.row_start_pos][subm2.column_end_pos] += M3 + M5
//for C21
m.matrix[subm1.row_end_pos][subm2.column_start_pos] += M2 + M4
//for C22
m.matrix[subm1.row_end_pos][subm2.column_end_pos] += M1 - M2 + M3 + M6
return
}
int rsp1 = subm1.row_start_pos
int rep1 = subm1.row_end_pos
int csp1 = subm1.column_start_pos
int cep1 = subm1.column_end_pos
int rsp2 = subm2.row_start_pos
int rep2 = subm2.row_end_pos
int csp2 = subm2.column_start_pos
int cep2 = subm2.column_end_pos
//C11 = A11 * B11 + A12 * B21
//get A11 * B11
//get A11
subm1.row_start_pos = rsp1
subm1.row_end_pos = rsp1 + (rep1 - rsp1) / 2
subm1.column_start_pos = csp1
subm1.column_end_pos = csp1 + (cep1 - csp1) / 2
//get B11
subm2.row_start_pos = rsp2
subm2.row_end_pos = rsp2 + (rep2 - rsp2) / 2
subm2.column_start_pos = csp2
subm2.column_end_pos = csp2 + (cep2 - csp2) / 2
strassen_sub_multi_matrix(subm1, subm2, m)
//get A12 * B21
//get A12
subm1.row_start_pos = rsp1
subm1.row_end_pos = rsp1 + (rep1 - rsp1) / 2
subm1.column_start_pos = csp1 + (cep1 - csp1) / 2 + 1
subm1.column_end_pos = cep1
//get B21
subm2.row_start_pos = rsp2 + (rep2 - rsp2) / 2 + 1
subm2.row_end_pos = rep2
subm2.column_start_pos = csp2
subm2.column_end_pos = csp2 + (cep2 - csp2) / 2
strassen_sub_multi_matrix(subm1, subm2, m)
//C12 = A11 * B12 + A12 * B22
//get A11 * B12
//get A11
subm1.row_start_pos = rsp1
subm1.row_end_pos = rsp1 + (rep1 - rsp1) / 2
subm1.column_start_pos = csp1
subm1.column_end_pos = csp1 + (cep1 - csp1) / 2
//get B12
subm2.row_start_pos = rsp2
subm2.row_end_pos = rsp2 + (rep2 - rsp2) / 2
subm2.column_start_pos = csp2 + (cep2 - csp2) / 2 + 1
subm2.column_end_pos = cep2
strassen_sub_multi_matrix(subm1, subm2, m)
//get A12 * B22
//get A12
subm1.row_start_pos = rsp1
subm1.row_end_pos = rsp1 + (rep1 - rsp1) / 2
subm1.column_start_pos = csp1 + (cep1 - csp1) / 2 + 1
subm1.column_end_pos = cep1
//get B22
subm2.row_start_pos = rsp2 + (rep2 - rsp2) / 2 + 1
subm2.row_end_pos = rep2
subm2.column_start_pos = csp2 + (cep2 - csp2) / 2 + 1
subm2.column_end_pos = cep2
strassen_sub_multi_matrix(subm1, subm2, m)
//for C21 = A21 * B11 + A22 * B21
//get A21 * B11
//get A21
subm1.row_start_pos = rsp1 + (rep1 - rsp1) / 2 + 1
subm1.row_end_pos = rep1
subm1.column_start_pos = csp1
subm1.column_end_pos = csp1 + (cep1 - csp1) / 2
//get B11
subm2.row_start_pos = rsp2
subm2.row_end_pos = rsp2 + (rep2 - rsp2) / 2
subm2.column_start_pos = csp2
subm2.column_end_pos = csp2 + (cep2 - csp2) / 2
strassen_sub_multi_matrix(subm1, subm2, m)
//get A22
subm1.row_start_pos = rsp1 + (rep1 - rsp1) / 2 + 1
subm1.row_end_pos = rep1
subm1.column_start_pos = csp1 + (cep1 - csp1) / 2 + 1
subm1.column_end_pos = cep1
//get B21
subm2.row_start_pos = rsp2 + (rep2 - rsp2) / 2 + 1
subm2.row_end_pos = rep2
subm2.column_start_pos = csp2
subm2.column_end_pos = csp2 + (cep2 - csp2) / 2
strassen_sub_multi_matrix(subm1, subm2, m)
//for C22 = A21 * B12 + A22 * B22
//get A21 * B12
//get A21
subm1.row_start_pos = rsp1 + (rep1 - rsp1) / 2 + 1
subm1.row_end_pos = rep1
subm1.column_start_pos = csp1
subm1.column_end_pos = csp1 + (cep1 - csp1) / 2
//get B12
subm2.row_start_pos = rsp2
subm2.row_end_pos = rsp2 + (rep2 - rsp2) / 2
subm2.column_start_pos = csp2 + (cep2 - csp2) / 2 + 1
subm2.column_end_pos = cep2
strassen_sub_multi_matrix(subm1, subm2, m)
//get A22 * B22
//get A22
subm1.row_start_pos = rsp1 + (rep1 - rsp1) / 2 + 1
subm1.row_end_pos = rep1
subm1.column_start_pos = csp1 + (cep1 - csp1) / 2 + 1
subm1.column_end_pos = cep1
//get B22
subm2.row_start_pos = rsp2 + (rep2 - rsp2) / 2 + 1
subm2.row_end_pos = rep2
subm2.column_start_pos = csp2 + (cep2 - csp2) / 2 + 1
subm2.column_end_pos = cep2
strassen_sub_multi_matrix(subm1, subm2, m)
}
int main(int argc,char** argv) {
struct Matrix m1
struct Matrix m2
create_random_matrix(m1,2,2)
create_random_matrix(m2,2,2)
print_matrix(m1)
print_matrix(m2)
normal_multi_matrix(m1, m2)
normal_sub_multi_matrix(m1, m2)
delete_matrix(m1)
delete_matrix(m2)
system("pause")
return 0
}
