题意:
function z(array a, integer k): if length(a) < k: return 0 else: b = empty array ans = 0 for i = 0 .. (length(a) - k): temp = a[i] for j = i .. (i + k - 1): temp = max(temp, a[j]) append temp to the end of b ans = ans + temp return ans + z(b, k)
给出a和k求z(a,k)
分析:
先定义cnt[l][r]表示区间[l,r]的子区间被计算的次数,比如k=3时,cnt[3,7]=4,即第一层的[3,5],[4,6],[5,7]和这三个区间的最大值组成的第二层的唯一的结果,可以发现这样的递推关系cnt[l][r]-cnt[l][r-1] = 以r为结尾的区间被计算的次数 = ((r-l-1) - 1)/(k-1), 且易发现cnt只和区间长度有关,所以可以用cnt[x] = cnt[x-1] + (x-1)/(k-1)计算出,x为区间长度
再来考虑如何计算答案设以a[i]为最大值的最大区间为[l,r],a[i]的有效次数等于cnt[r-l+1]-cnt[i-l]-cnt[r-i],即所有最终会被a[i]影响的个数减去还没有包含到a[i]的个数,每个i的[l,r]可以单调栈O(n)求出
此外,有一个非常关键的情况就是对于a[i]=a[j]且a[k]<a[i](i<k<j)的情况,即有多个值相同的数之间没有比他们大的数,此时a[i]和a[j]最后会在某处会被归并成同一个最大值,但是会被重复计算,因为那个最大值即被认为是a[i]造成的也被认为是a[j]造成的,所以我们在单调栈处理区间时可以把a[i]看成pair(a[i],-i)来进行大小比较,即在值相同的情况下,我们认为前面的会覆盖后面的而后面的会消失,这样相同值归并产生的区间就只会在这个值第一次出现的位置被计算一次,就可以去掉重复情况
#include<iostream>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#define pii pair<int, int>
#define fi first
#define se second
#define mk make_pair
#define pb push_back
#define sc(x) scanf("%d", &x);
#define scl(x) scanf("%I64d", &x);
#define frein freopen("in.txt", "r", stdin);
#define freout freopen("out.txt", "w", stdout);
#define freout1 freopen("out1.txt", "w", stdout);
using namespace std;
typedef long long LL;
const int INF = 0x3f3f3f3f;
const int mod = 1e9+7;
const int maxn = 1e6+10;
int n, k, cnt[maxn], a[maxn], L[maxn], R[maxn];
stack<pii> stk;
int main(){
//frein;
//freout;
sc(n); sc(k);
for(int i = 1; i <= n; i++) sc(a[i]);
stk.push({INF,0});
for(int i = 1; i <= n; i++){
while(stk.top().fi < a[i]) stk.pop();
L[i] = stk.top().se+1;
stk.push({a[i],i});
}
while(stk.size()) stk.pop();
stk.push({INF,n+1});
for(int i = n; i >= 1; i--){
while(stk.top().fi <= a[i]) stk.pop();
R[i] = stk.top().se-1;
stk.push({a[i],i});
}
for(int i = 1; i <= n; i++){
cnt[i] = (cnt[i-1] + (i-1)/(k-1))%mod;
}
LL ans = 0;
for(int i = 1; i <= n; i++){
int t = (cnt[R[i]-L[i]+1] - cnt[i-L[i]] - cnt[R[i]-i])%mod;
ans += 1LL*t*a[i]%mod;
}
printf("%I64d\n", (ans%mod + mod)%mod);
return 0;
}