题目描述
题解
考虑分治[L,R][L,R][L,R]内的答案,连续区间一端位于[L,mid][L,mid][L,mid],另一端位于[mid+1,R][mid+1,R][mid+1,R]的答案有多少。其余的递归处理。我们知道一段连续区间的特征是Max−Min=R−L.Max-Min=R-L.Max−Min=R−L.
假设Maxi=max(ai,ai+1,...,amid). Mini=min(ai,ai+1,...,amid).Max_i=max(ai,a_{i+1},...,a_{mid}).\ \ Min_i=min(a_i,a_{i+1},...,a_{mid}).Maxi=max(ai,ai+1,...,amid). Mini=min(ai,ai+1,...,amid).
Maxj=min(aj,aj−1,...,amid), Minj=min(aj,aj−1,...,amid)\ \ \ \ \ \ Max_j=min(a_j,a_{j-1},...,a_{mid}),\ \ Min_j=min(a_j,a_{j-1},...,a_{mid}) Maxj=min(aj,aj−1,...,amid), Minj=min(aj,aj−1,...,amid)
那么我们可以根据最大值和最小值的关系得到如下四个等式,每一个连续区间[i,j]必然满足其中关系的一个。{Maxi−Mini=j−i①Maxi−Minj=j−i②Maxj−Minj=j−i③Maxj−Mini=j−i④\begin{cases} Max_i-Min_i=j-i①\\Max_i-Minj=j-i②\\Max_j-Minj=j-i③\\Maxj-Mini=j-i④\end{cases}⎩⎪⎪⎪⎨⎪⎪⎪⎧Maxi−Mini=j−i①Maxi−Minj=j−i②Maxj−Minj=j−i③Maxj−Mini=j−i④
在这里,①③和②④本质相同,只需要对数组做一遍翻转操作即可。我们来考虑如何计算①和②。
-
对于①的情况,我们可以枚举i∈[L,mid],i∈[L,mid],i∈[L,mid],计算出j=Maxi−Mini+ij=Max_i-Min_i+ij=Maxi−Mini+i.
判断j∈[Mid+1,R]j∈[Mid+1,R]j∈[Mid+1,R]是否满足,且是否满足如下条件即可:{Maxi>MaxjMini<Minj\begin{cases} Max_i>Max_j\\Min_i<Min_j\end{cases}{Maxi>MaxjMini<Minj -
对于②的情况,我们可以枚举i∈[L,Mid]i∈[L,Mid]i∈[L,Mid],计算出有多少对j+Minj=i+Minij+Min_j=i+Min_ij+Minj=i+Mini即可。
其中j的统计用数组标记,我们用双指针找到合法的j的范围然后统计到数组里面累加即可。这样对于i也能够直接查找。由于有两个约束条件:{Maxi>MaxjMini>Minj\begin{cases}Max_i>Max_j\\Min_i>Min_j\end{cases}{Maxi>MaxjMini>Minj
有指针rrr枚举满足第二个条件的右端点,用lll去掉非法的左端点。这里有一点莫队的思想。
代码其实十分简短:
#include <bits/stdc++.h>
using namespace std;
const int N = 3000000;
int n;
long long ans = 0;
int Max[N], Min[N], a[N], cnt[N];
void work(int L,int R,int Mid)
{
Max[Mid] = Min[Mid] = a[Mid];
for (int i=Mid-1;i>=L;--i)
{
Max[i] = max(Max[i+1],a[i]);
Min[i] = min(Min[i+1],a[i]);
}
Max[Mid+1] = Min[Mid+1] = a[Mid+1];
for (int i=Mid+2;i<=R;++i)
{
Max[i] = max(Max[i-1],a[i]);
Min[i] = min(Min[i-1],a[i]);
}
//maxi-mini=j-i
for (int i=L;i<=Mid;++i)
{
int j = i+Max[i]-Min[i];
if (j > Mid && j <= R && Max[i] > Max[j] && Min[i] < Min[j])
ans ++;
}
//maxi-minj=j-i
int l =Mid+1, r = Mid;
for (int i=Mid;i>=L;--i)
{
while (Max[r+1] < Max[i] && r < R)
r ++, cnt[r+Min[r]] ++;
while (Min[l] > Min[i] && l <= r)
cnt[l+Min[l]] --, l ++;
ans += 1LL * cnt[i+Max[i]];
}
while (l <= r) cnt[l+Min[l]] --, l ++;
return;
}
void Solve(int L,int R)
{
if (L == R) ans ++;
if (L >= R) return;
int mid = L+R >> 1;
Solve(L,mid), Solve(mid+1,R);
int Mid = L+R >> 1;
work(L,R,mid);
reverse(a+L,a+R+1);
if ((L+R+1) % 2 == 0) work(L,R,mid);
else work(L,R,mid-1);
reverse(a+L,a+R+1);
return;
}
int main(void)
{
scanf("%d", &n);
for (int i=1;i<=n;++i)
scanf("%d", a+i);
Solve(1,n);
cout<<ans<<endl;
return 0;
}