前言
不知道大家有没有看到这次的题目——“数学杂谈”,当然既然是杂谈了,那么我们就来闲聊一下这次的内容吧。
其实 矩阵加速 看上去是一个很(玄学)哲学的东西,要是突然上天眷顾,那么就能够想出来,可是要是(点背)emmm,这其实也就没多大用了。
因此这个东西真的就只是像他所说的,只是起到了一个 加速 的作用而已,如果真的想到了正解,这也就没多大用了吧。吧。吧。(除非真的有些防ak的题。。。)不过我们还是尽量以拓宽知识面为主,能多啃点就多啃点咯,(要是到时候真的用这个骗到了分的话,岂不是美滋滋、、、)
题目
emmm其实没有什么题目啦,就正如大家所熟知的 斐波那契数列和前n项和%指定数 而已。
哈哈哈哈(逃)
解析
当然这个斐波那契不一般,且看数据:
要是这个真的按平时的我们的方法来算的话,恐怕面对的只有 TLE 罢了,因此在这里就有了矩阵加速的一席之地了。
在讲解这道题之前,让我们先来看下矩阵相乘吧。两个矩阵的乘法仅当第一个矩阵A的列数和第二个矩阵B的行数相等时才能定义(做乘法)。如A是m×n矩阵,B是n×p矩阵,它们的乘积C是一个m×p矩阵C=(cij),它的任意一个元素值为:
矩阵运算之乘法
两个矩阵的乘法仅当——
第一个矩阵A的列数和第二个矩阵B的行数相等时才能定义(做乘法)。如A是m×n矩阵,B是n×p矩阵,它们的乘积C是一个m×p矩阵C=(cij)(行取第一个矩阵的行,列取第二个矩阵的列),它的任意一个元素值为:
或者说是这样:
其实总结起来就是一句话:将第一个矩阵转过来,如果他们高度相等(就是行数相等),那么就能够做乘法了。
而也很好理解,就是第一个矩阵A的第i行和第二个矩阵B的第j列对应相乘的结果。 (这个十分重要)
(划重点哦~~~)
这个呢就是一个矩阵相乘的例子,大家可以结合着上面的总结看一下,然后再自己找几个矩阵相乘练一下,这个真的十分重要,最好能练成那种光看就能在脑子里推出来的那种哦~~(当然不能急于求成,慢慢来)
单位矩阵
当然在这道题中我们还需要补充一个知识——单位矩阵。
那么单位矩阵是什么呢,其实就是矩阵的单位(哈哈哈哈)。
当然它是一个方阵,也就是行列相同的矩阵(跟正方形好像也差不多)
意思就是说不管什么矩阵乘上它都是它本身,这么说起来大家有没有想到“1”呢?没错,他跟1确实差不多(至少我是这么理解的)。
当然这里就直接给出来了吧:从左上角到右下角都为1,其他为0。例如:
这就是一个行列为3的单位矩阵。
特殊变换
(实在不知道该取什么名字了,大家凑合着看吧。。。)
我们在这道题中定义一个1行2列的矩阵,分别表示和
。(要求和的话就在前面加上一列表示
吧)
接下来我们思考下如何将这个矩阵乘另外一个矩阵来得到:
如果说要从f1,f2变成f2,f3,是不是第一个f1就不能要了,取而代之的就是f2。
也就是说——这个矩阵的(1,1)是等于0的(因为(1,1)要跟f1相乘,作为ans矩阵的第一个元素的部分),
而(2,1)则应该为1(因为(2,1)跟f2相乘,作为ans矩阵第一个元素的一个部分)。
后面则以此类推——因为f3 = f1 + f2,所以这个矩阵的(1,2)和(2,2)都应该为1。
也就是说这个矩阵应该是长这样的:
(如果没看得很懂的童鞋们可以自己再写一下这两个矩阵,比划一下,这个必须搞懂,不然的话求和的话就比较恼火了)
在这里就只提一下,因为si+1是等于si再加上fi-1和fi的,因此第一列全都为1:
弄懂了这个,我们就可以通过矩阵快速幂来做了!,也就是将原矩阵乘上n - 2个变换矩阵(就是变换矩阵的n-2次方)就ok啦!!
至于为什么是n-2,emmm,我觉得这个我就不用说了吧。。。你们自己再想想吧。(逃)
参考代码
Fibonacci 第 n 项:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <cstdlib>
using namespace std;
#define LL long long
#define min(a, b) a < b ? a : b
#define max(a, b) a > b ? a : b
void read (int &x){
x = 0;
char c = getchar ();
while (c < '0' || c > '9')
c = getchar ();
while (c >= '0' && c <= '9'){
x = (x << 1) + (x << 3) + c - 48;
c = getchar ();
}
}
void Print (LL x){
if (x < 0){
putchar ('-');
x = ~x + 1;
}
if (x / 10) Print (x / 10);
putchar (x % 10 + 48);
}
int n, mod;
struct Matrix{
int n, m;
LL a[2][2];
Matrix (){
n = m = 0;
memset (a, 0, sizeof (a));
}
void print (){
Print (a[0][1]);
putchar (10);
}
Matrix operator * (const Matrix &rhs){
Matrix k;
k.n = n; k.m = rhs.m;
for (int i = 0; i < k.n; i++)
for (int j = 0; j < k.m; j++)
for (int kk = 0; kk < m; kk++)
k.a[i][j] = (k.a[i][j] + (a[i][kk] * rhs.a[kk][j] % mod)) % mod;
return k;
}
};
Matrix qkpow (Matrix a, int b){
Matrix tot;
tot.n = tot.m = 2;
for (int i = 0; i < tot.n; i++)
tot.a[i][i] = 1;
while (b){
if (b & 1)
tot = tot * a;
a = a * a;
b >>= 1;
}
return tot;
}
int main (){
read (n); read (mod);
Matrix ans, a;
ans.n = 1, ans.m = 2;
a.n = a.m = 2;
a.a[0][0] = 0, a.a[0][1] = a.a[1][0] = a.a[1][1] = 1;
ans.a[0][0] = ans.a[0][1] = 1;
ans = ans * qkpow (a, n - 2);
ans.print ();
}
Fibonacci前n项和:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <cstdlib>
using namespace std;
#define LL long long
#define min(a, b) a < b ? a : b
#define max(a, b) a > b ? a : b
void read (int &x){
x = 0;
char c = getchar ();
while (c < '0' || c > '9')
c = getchar ();
while (c >= '0' && c <= '9'){
x = (x << 1) + (x << 3) + c - 48;
c = getchar ();
}
}
void Print (LL x){
if (x < 0){
putchar ('-');
x = ~x + 1;
}
if (x / 10) Print (x / 10);
putchar (x % 10 + 48);
}
int n, mod;
struct Matrix{
int n, m;
LL a[5][5];
Matrix (){
memset (a, 0, sizeof (a));
n = 0; m = 0;
}
void print (){
Print (a[1][1]);
putchar (10);
}
Matrix operator * (const Matrix &rhs){
Matrix k;
k.n = n; k.m = rhs.m;
for (int i = 1; i <= k.n; i++)
for (int j = 1; j <= k.m; j++)
for (int kk = 1; kk <= m; kk++)
k.a[i][j] = (k.a[i][j] + (a[i][kk] * rhs.a[kk][j]) % mod) % mod;
return k;
}
};
Matrix qkpow (int b){
Matrix a, tot;
a.n = a.m = tot.n = tot.m = 3;
a.a[1][1] = a.a[2][1] = a.a[2][3] = a.a[3][1] = a.a[3][2] = a.a[3][3] = 1;
for (int i = 1; i <= tot.n; i++)
tot.a[i][i] = 1;
while (b){
if (b & 1)
tot = tot * a;
a = a * a;
b >>= 1;
}
return tot;
}
int main (){
read (n); read (mod);
Matrix ans; ans.n = 1, ans.m = 3;
ans.a[1][1] = 2, ans.a[1][2] = 1, ans.a[1][3] = 1;
ans = ans * qkpow (n - 2);
ans.print ();
}
emmm这个代码仅供参考,能不能对就不能保证了。。。(哈哈哈哈溜)