http://poj.org/problem?id=3233
Matrix Power Series
Time Limit: 3000MS | Memory Limit: 131072K
Total Submissions: 13472 | Accepted: 5809
Description
Given an n × n matrix A and a positive integer k, find the sum $S = A + A^2 + A^3 + \dots + A^k$.
Input
The input contains exactly one test case. The first line of input contains three positive integers n ($n \le 30$), k ($k \le 10^9$) and m ($m < 10^4$). Then follow n lines, each containing n nonnegative integers below 32,768, giving A’s elements in row-major order.
Output
Output the elements of S modulo m in the same way as A is given.
Sample Input
2 2 4
0 1
1 1
Sample Output
1 2
2 3
Source
POJ Monthly--2007.06.03, Huang, Jinsong
(强烈吐槽CSDN的格式混乱)
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<algorithm>
#include<ctime>
#include<cctype>
#include<cmath>
#include<string>
#include<cstring>
#include<queue>
#include<vector>
// Square of x. The whole expansion is parenthesised so the macro behaves as a
// single operand: 1000/sqr(10) == 10, not 1000/(10)*(10) == 1000.
// Note: x is still evaluated twice, so do not pass expressions with side effects.
#define sqr(x) ((x)*(x))
#define INF 0x1f1f1f1f
#define PI 3.1415926535
// NOTE(review): `mm` expands to nothing. Nothing in this program uses it, but
// any later identifier spelled `mm` would silently disappear.
#define mm
using namespace std;
// Problem parameters read from stdin: the modulus, the matrix order n and the
// exponent k.
int mod;
int n;
int k;

// A square matrix; only the top-left n x n corner of m is used (n <= 30).
// The x/y fields record the dimensions but are never read in this program.
struct mat
{
    int x;
    int y;
    int m[33][33];
};
// a: the input matrix, e: the n x n identity, zero: the n x n zero matrix.
mat a, e, zero;

// A 2x2 block matrix whose four blocks are n x n matrices:
//   [ a  b ]
//   [ c  d ]
struct mat2_of_mat
{
    mat a, b,
        c, d;
};
// m: the block matrix being raised to a power; t: the accumulator holding the answer.
mat2_of_mat m, t;
// Debug helper: print the top-left n x n corner of o, one row per line,
// with each value followed by a space.
void debug(mat o)
{
    for (int row=0;row<n;++row)
    {
        for (int col=0;col<n;++col)
            printf("%d ",o.m[row][col]);
        printf("\n");
    }
}
// Prepare the block matrices for the fast-power loop:
//   m = [[A, I], [0, I]]   -- the block matrix that gets raised to a power
//   t = [[I, 0], [0, I]]   -- the accumulator, initialised to the block identity
// Initialising t wrongly (e.g. with t.b = I instead of zero) was the source of
// the long run of wrong answers.
// k is incremented because the top-right block of m^(k+1) equals
// I + A + ... + A^k. main() subtracts the extra I afterwards.
void init()
{
    memset(&zero,0,sizeof(zero));
    memset(&e,0,sizeof(e));
    for (int d=0;d<n;++d)
        e.m[d][d]=1;

    a.x=n;    a.y=n;
    e.x=n;    e.y=n;
    zero.x=n; zero.y=n;

    m.a=a;    m.b=e;
    m.c=zero; m.d=e;

    t.a=e;    t.b=zero;
    t.c=zero; t.d=e;

    ++k;
}
// Matrix multiplication: return a*b for n x n matrices, modulo `mod`.
//
// Every entry of the result is fully reduced into [0, mod). The previous version
// reduced each individual product but not the running sum, so it returned values
// up to about n*(mod-1). Feeding such unreduced values into another
// multiplication can overflow int; that is why replacing the plus_mat(...) call
// in mul_mat2_of_mat with a bare mul_mat gave WA.
// Products are formed in long long, so unreduced operands (the raw input
// entries, which are below 32768) also cannot overflow.
// Operands are taken by const reference to avoid copying two matrices per call.
mat mul_mat(const mat& a, const mat& b)
{
    mat temp;
    memset(&temp,0,sizeof(temp));
    for (int i=0;i<n;i++)
    {
        for (int j=0;j<n;j++)
        {
            long long acc=0;
            // `p`, not `k`: avoid shadowing the global exponent k.
            for (int p=0;p<n;p++)
            {
                acc=(acc+static_cast<long long>(a.m[i][p])*b.m[p][j])%mod;
            }
            temp.m[i][j]=static_cast<int>(acc);
        }
    }
    return temp;
}
// Matrix addition: the entry-wise sum of two n x n matrices, each entry reduced
// modulo `mod`. Entries outside the n x n corner (and x/y) are left zero.
mat plus_mat(mat a, mat b)
{
    mat sum;
    memset(&sum,0,sizeof(sum));
    for (int r=0;r<n;++r)
        for (int c=0;c<n;++c)
            sum.m[r][c]=(a.m[r][c]+b.m[r][c])%mod;
    return sum;
}
// Multiply two 2x2 block matrices, assuming both have the shape [[*, *], [0, I]]
// (the only shape this program multiplies). For such operands
//   [[xa, xb], [0, I]] * [[ya, yb], [0, I]] = [[xa*ya, xa*yb + xb], [0, I]],
// so the bottom row of the result is simply reset to [zero, e].
//
// Why the top-left block keeps the mathematically redundant xb*yc term inside
// plus_mat: plus_mat reduces every entry modulo `mod`. If mul_mat leaves its
// accumulated sums unreduced, writing `temp.a = mul_mat(x.a, y.a)` lets the
// entries grow past `mod`, and a later multiplication can overflow int (WA).
mat2_of_mat mul_mat2_of_mat(mat2_of_mat x, mat2_of_mat y)
{
    mat2_of_mat prod;
    memset(&prod,0,sizeof(prod));
    const mat top_left=plus_mat(mul_mat(x.a,y.a),mul_mat(x.b,y.c));
    const mat top_right=plus_mat(mul_mat(x.a,y.b),x.b);
    prod.a=top_left;
    prod.b=top_right;
    prod.c=zero;
    prod.d=e;
    return prod;
}
// Quick power (binary exponentiation) over block matrices: on return,
// t = t * m^k (t starts as the block identity, so t = m^k). The global k is
// consumed and ends at 0.
void qp() //quick_power
{
    for (;k!=0;k>>=1)
    {
        if ((k&1)!=0)
            t=mul_mat2_of_mat(t,m);
        m=mul_mat2_of_mat(m,m);
    }
}
// Read n, k, mod and the n x n matrix A, raise [[A, I], [0, I]] to the power
// k+1, and print its top-right block minus the identity, i.e.
// S = A + A^2 + ... + A^k modulo mod.
int main()
{
    scanf("%d%d%d",&n,&k,&mod);
    for (int r=0;r<n;++r)
        for (int c=0;c<n;++c)
            scanf("%d",&a.m[r][c]);

    init();
    qp();

    for (int r=0;r<n;++r)
    {
        for (int c=0;c<n;++c)
        {
            int val=t.b.m[r][c];
            if (r==c)
                val=(val-1+mod)%mod;   // subtract the identity term
            printf("%d ",val);
        }
        printf("\n");
    }
    return 0;
}
后来发现不用写矩阵的二阶矩阵的快速幂,直接把分块矩阵 [[A, I], [0, I]] 当作一个 2n×2n 的普通矩阵做快速幂即可:它的 k+1 次幂的右上角块就是 I + A + … + A^k(教练,我想学线代……)
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<algorithm>
#include<ctime>
#include<cctype>
#include<cmath>
#include<string>
#include<cstring>
#include<queue>
#include<vector>
// Square of x. The whole expansion is parenthesised so the macro behaves as a
// single operand: 1000/sqr(10) == 10, not 1000/(10)*(10) == 1000.
// Note: x is still evaluated twice, so do not pass expressions with side effects.
#define sqr(x) ((x)*(x))
#define INF 0x1f1f1f1f
#define PI 3.1415926535
// NOTE(review): `mm` expands to nothing. Nothing in this program uses it, but
// any later identifier spelled `mm` would silently disappear.
#define mm
using namespace std;
// Problem parameters read from stdin: the modulus, the matrix order n and the
// exponent k.
int mod;
int n;
int k;

// A square matrix with room for 70x70 entries; this program uses the top-left
// 2n x 2n corner (n <= 30, so 2n <= 60). The x/y fields are never read.
struct mat
{
    int x;
    int y;
    int m[70][70];
};
// a: input n x n matrix, e: 2n x 2n identity, zero: zero matrix,
// m: block matrix [[A, I], [0, I]] being powered, t: accumulator / answer.
mat a, e, zero, m, t;
// Debug helper: print the top-left 2n x 2n corner of o, one row per line,
// with each value followed by a space.
void debug(mat o)
{
    const int dim=n*2;
    for (int row=0;row<dim;++row)
    {
        for (int col=0;col<dim;++col)
            printf("%d ",o.m[row][col]);
        printf("\n");
    }
}
// Build the 2n x 2n matrix m = [[A, I], [0, I]] and set t to the 2n x 2n identity.
// k is incremented because the top-right n x n block of m^(k+1) equals
// I + A + ... + A^k. main() subtracts the extra I afterwards.
void init()
{
    memset(&zero,0,sizeof(zero));
    memset(&e,0,sizeof(e));
    const int dim=n*2;
    for (int d=0;d<dim;++d)
        e.m[d][d]=1;

    t=e;
    m=e;   // supplies the bottom-right identity block (bottom-left stays zero)
    for (int r=0;r<n;++r)
    {
        for (int c=0;c<n;++c)
            m.m[r][c]=a.m[r][c];   // top-left block: A
        m.m[r][r+n]=1;             // top-right block: I
    }

    ++k;
}
// Multiply two 2n x 2n matrices modulo `mod` and return the product, with every
// entry reduced into [0, mod).
// - Operands are taken by const reference. A `mat` holds a 70x70 int array
//   (about 19.6 KB), so passing by value copied about 39 KB per call.
// - Each product is formed in long long. init() copies A's entries unreduced
//   (they are below 32768 per the problem statement), so the int product was
//   only safe because of that input bound.
mat mul_mat(const mat& a, const mat& b)
{
    mat temp;
    memset(&temp,0,sizeof(temp));
    const int dim=n*2;
    for (int i=0;i<dim;i++)
    {
        for (int j=0;j<dim;j++)
        {
            long long acc=0;
            // `p`, not `k`: avoid shadowing the global exponent k.
            for (int p=0;p<dim;p++)
            {
                acc=(acc+static_cast<long long>(a.m[i][p])*b.m[p][j])%mod;
            }
            temp.m[i][j]=static_cast<int>(acc);
        }
    }
    return temp;
}
// Quick power (binary exponentiation): on return, t = t * m^k (t starts as
// the identity, so t = m^k). The global k is consumed and ends at 0.
void qp()
{
    for (;k!=0;k>>=1)
    {
        if ((k&1)!=0)
            t=mul_mat(t,m);
        m=mul_mat(m,m);
    }
}
// Read n, k, mod and the n x n matrix A, raise the 2n x 2n matrix
// [[A, I], [0, I]] to the power k+1, and print its top-right n x n block minus
// the identity, i.e. S = A + A^2 + ... + A^k modulo mod.
int main()
{
    scanf("%d%d%d",&n,&k,&mod);
    for (int r=0;r<n;++r)
        for (int c=0;c<n;++c)
            scanf("%d",&a.m[r][c]);

    init();
    qp();

    for (int r=0;r<n;++r)
    {
        for (int c=0;c<n;++c)
        {
            int val=t.m[r][c+n];
            if (r==c)
                val=(val-1+mod)%mod;   // subtract the identity term
            printf("%d ",val);
        }
        printf("\n");
    }
    return 0;
}