第一次写poj的题目的题解呢QωQ\color{pink}Q\omega QQωQ。
题意简述
给定一个n×nn\times nn×n的矩阵AAA,求A+A2+A3⋯+AkA+A^2+A^3\cdots +A^kA+A2+A3⋯+Ak的值(其中k<=1e9k<=1e9k<=1e9)。每一项膜mmm。
数据
输入
2 2 4//n,k,m
0 1
1 1
输出
1 2
2 3
思路
特意注明k<=1e9k<=1e9k<=1e9,只是想说明我们都会写的暴力过不去了。。。
这不是在欺负女孩子(我)是在干嘛。。。
但是,正好最近学因式分解。我们设S(k)=A+A2+A3⋯+AkS(k)=A+A^2+A^3\cdots +A^kS(k)=A+A2+A3⋯+Ak,那么:
当kkk为偶数的时候,会发现珂以提一个公因式S(k/2)S(k/2)S(k/2),此时S(k)S(k)S(k)化为
=S(k/2)(I(n)+Ak/2)=S(k/2)(I(n)+A^{k/2})=S(k/2)(I(n)+Ak/2)。
(其中I(n)I(n)I(n)表示n×nn\times nn×n的单位矩阵)
当kkk为奇数的时候,S(k)=Ak+S(k−1)S(k)=A^k+S(k-1)S(k)=Ak+S(k−1)
注意到此时k−1k-1k−1是偶数,根据上面的变换,原式化为
=Ak+S(k/2)(I(n)+Ak/2)=A^k+S(k/2)(I(n)+A^{k/2})=Ak+S(k/2)(I(n)+Ak/2)
然后会发现,每次规模都会缩小一半,就是O(logk)O(logk)O(logk)的。在加上矩阵快速幂的复杂度,总复杂度就是O(logk×logk×n3)=O(log2kn3)O(logk\times logk\times n^3)=O(log^2kn^3)O(logk×logk×n3)=O(log2kn3)。
代码:
#include<cstdio>
#define N 35
using namespace std;
int n,k,m;
struct mat
{
int m[N][N];
int* operator[](int i)//封装下标[]运算符
{
return m[i];
}
mat(int x)//初始值设为x
{
for(int i=0;i<N;i++)
{
for(int j=0;j<N;j++)
{
m[i][j]=x;
}
}
}
void Set(int x=0)//这个方便后期设置
{
for(int i=0;i<N;i++)
{
for(int j=0;j<N;j++)
{
m[i][j]=x;
}
}
}
void Identity()//设置成单位矩阵
{
Set(0);
for(int i=0;i<N;i++)
{
m[i][i]=1;
}
}
};
mat operator+(mat x,mat y)//矩阵加法
{
mat s(0);
for(int i=1;i<=n;i++)
{
for(int j=1;j<=n;j++)
{
s[i][j]=(x[i][j]+y[i][j])%m;
}
}
return s;
}
mat operator*(mat x,mat y)//矩阵乘法
{
mat ans(0);
for(int i=1;i<=n;i++)
{
for(int j=1;j<=n;j++)
{
ans[i][j]=0;
for(int k=1;k<=n;k++)
{
ans[i][j]+=x[i][k]*y[k][j];
ans[i][j]%=m;
}
}
}
return ans;
}
mat operator^(mat x,int p)//矩阵快速幂
{
mat ans(0);ans.Identity();
while(p)
{
if (p&1) ans=ans*x;
x=x*x,p>>=1;
}
return ans;
}
mat a(0);
void Input()
{
scanf("%d%d%d",&n,&k,&m);
for(int i=1;i<=n;i++)
{
for(int j=1;j<=n;j++)
{
scanf("%d",&a[i][j]);
}
}
}
mat calc(int k)//计算S(k)
{
if (k==1) return a;//边界(这个千万别忘了!!!)
mat ans(0);
ans.Identity();
ans=ans+(a^(k>>1));//I(n)+A^(k>>1)
ans=ans*calc(k>>1);//S(k>>1)×(I(n)+A^(k>>1))
if (k&1) ans=ans+(a^k);//由于>>1的值奇数和偶数没区别,奇数和偶数的区别就只是在这里
return ans;
}
void Solve()
{
mat tmp=calc(k);
for(int i=1;i<=n;i++)
{
for(int j=1;j<=n;j++)
{
tmp[i][j]%=m;//为了保险(具体有没有用。。。我也不知道。。。)
printf("%d ",tmp[i][j]);
}
putchar('\n');
}
}
int main()
{
Input();
Solve();
return 0;
}