题目链接:
Calculate the Function
题意:
给你一个区间 [L, R] ,再给你一个递推关系 f(L) = A[L] ; f(L+1) = A[L+1] ; f(x) = f(x - 1) + f(x - 2) * A[x] (x > 2) ,然后让你求 f(R)的值。
题解:
看到有递推式,首先想到矩阵快速幂,那么递推矩阵是什么呢?
就是这个矩阵,那么大家会发现这个矩阵不是定值,而实随着x的变化,矩阵发生变化,那么就不可能用矩阵快速幂了。但是暴力
矩阵相乘会超时,所以可以用线段是来优化。线段树的每个节点维护一个矩阵乘积,当 R - L <= 2时候直接特判输出,大于2时候
就可以通过线段树查询区间内矩阵相乘,最后再用乘以
到r之间的矩阵乘积就可以得出来
了。
AC代码:
#include<bits/stdc++.h>
#define up(i, x, y) for(ll i = x; i <= y; i++)
#define down(i, x, y) for(ll i = x; i >= y; i--)
#define maxn ((ll)1e5 + 10)
#define INF 0x3f3f3f3f
using namespace std;
typedef long long ll;
ll mod = (ll)1e9 + 7;
struct node
{
ll mat[3][3]; //节点维护矩阵
ll l, r;
}t[maxn * 4 + 1]; //定义线段树节点
ll arr[maxn];
node ans;
void fun( ll A[][3] , ll B[][3] ,ll C[][3]) //模拟矩阵相乘
{
C[1][1] = (A[1][1] * B[1][1] + A[1][2] * B[2][1]) % mod;
C[1][2] = (A[1][1] * B[1][2] + A[1][2] * B[2][2]) % mod;
C[2][1] = (A[2][1] * B[1][1] + A[2][2] * B[2][1]) % mod;
C[2][2] = (A[2][1] * B[1][2] + A[2][2] * B[2][2]) % mod;
}
void build(ll k, ll l, ll r)
{
t[k].l = l, t[k].r = r;
if(l == r)
{
scanf("%lld", &t[k].mat[1][2]);
t[k].mat[1][1] = 1;
t[k].mat[2][2] = 0;
t[k].mat[2][1] = 1;
arr[l] = t[k].mat[1][2];
return ;
}
ll m = (l + r) / 2;
build(k * 2, l, m);
build(k * 2 + 1,m + 1, r);
fun(t[k * 2 + 1].mat, t[k * 2].mat, t[k].mat); //递推需要按照矩阵乘积来
}
void qur(ll k, ll l, ll r)
{
if(l <= t[k].l && t[k].r <= r)
{
if(ans.mat[1][1] == 0)
{
ans.mat[1][1] = t[k].mat[1][1];
ans.mat[1][2] = t[k].mat[1][2];
ans.mat[2][1] = t[k].mat[2][1];
ans.mat[2][2] = t[k].mat[2][2];
}
else
{
ll tt[3][3];
tt[1][1] = ans.mat[1][1];
tt[1][2] = ans.mat[1][2];
tt[2][1] = ans.mat[2][1];
tt[2][2] = ans.mat[2][2];
fun(t[k].mat, tt, ans.mat);
}
return ;
}
ll m = (t[k].r + t[k].l) / 2;
if(l <= m) qur(k * 2, l, r);
if(r >= m + 1) qur(k * 2 + 1, l, r);
}
int main()
{
ll T; while(~scanf("%lld", &T))
{
while(T--)
{
ll n, m; scanf("%lld %lld", &n, &m);
build(1, 1, n);
while(m--)
{
ll x, y; scanf("%lld %lld", &x, &y);
if(y - x <= 1)
{
if(y == x)
{
printf("%lld\n", arr[y]);
}
else
{
printf("%lld\n", arr[y]);
}
continue;
}
ans.mat[1][1] = 0;
qur(1, x + 2, y); // 查询区间矩阵乘积
ll tmp = (ans.mat[1][1] * arr[x+1] + ans.mat[1][2] * arr[x]) % mod;
printf("%lld\n", tmp);
}
}
}
}