牛客练习赛的一道题目
先想到一个思路,对于一个固定的左端点,右边第一个比他大的数下标为i,第二个为j,则右端点必然在[i,j-1]之间,同时对于每一个右端点满足要求的左端点也在一个区间中,还需要检查固定的左端点是否在该区间中。直接检查的话会超时,可以检索每个点为右端点,对应的区间中有多少成立的左端点转化为区间和问题,可用树状数组解决。而某个点A作为左端点成立时仅在右端点为[i,j-1]是成立,只需要在枚举到以i为右端点时将A激活(即树状数组add(A,1)),在枚举j时再消除他的影响即可。最后还要解决对于i找前两个a[i]大的数的下标和后两个比a[i]小的下标,可以分别按照a[i]值由大到小,由小到大插入下标,利用权值线段树或set查找前驱和后继。
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+5;
int n,a[maxn],pre[maxn][2],suf[maxn][2],c[maxn],sa[maxn];
void add(int x,int v)
{
while(x<=n)
{
c[x]+=v;
x+=x&(-x);
}
}
int query(int x)
{
int ans=0;
while(x)
{
ans+=c[x];
x-=x&(-x);
}
return ans;
}
int tree[4*maxn];
void insert(int o,int l,int r,int x)
{
if(l==r)
{
tree[o]+=1;
return;
}
int lc=2*o,rc=2*o+1,mid=l+(r-l)/2;
if(x<=mid) insert(lc,l,mid,x);
else insert(rc,mid+1,r,x);
tree[o]=tree[lc]+tree[rc];
}
int findpre(int o,int l,int r,int x)
{
if(x<1) return 0;
if(l==r)
{
if(tree[o])
return l;
else return 0;
}
int lc=2*o,rc=2*o+1,mid=l+(r-l)/2;
if(x<=mid) return findpre(lc,l,mid,x);
else
{
int ans=0;
if(tree[rc])
ans=findpre(rc,mid+1,r,x);
if(!ans)
ans=findpre(lc,l,mid,x);
return ans;
}
}
int findsuf(int o,int l,int r,int x)
{
if(x>n) return n+1;
if(l==r)
{
if(tree[o]) return l;
else return n+1;
}
int lc=2*o,rc=lc+1,mid=l+(r-l)/2;
if(x>mid) return findsuf(rc,mid+1,r,x);
else
{
int ans=n+1;
if(tree[lc]) ans=findsuf(lc,l,mid,x);
if(ans>n)
ans=findsuf(rc,mid+1,r,x);
return ans;
}
}
vector<int> mp[maxn][2];
int main()
{
//freopen("in.txt","r",stdin);
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
sa[a[i]]=i;
}
for(int i=1;i<=n;i++)
{
suf[sa[i]][0]=findsuf(1,1,n,sa[i]+1);
suf[sa[i]][1]=findsuf(1,1,n,suf[sa[i]][0]+1);
insert(1,1,n,sa[i]);
}
memset(tree,0,sizeof(tree));
for(int i=n;i>=1;i--)
{
pre[sa[i]][0]=findpre(1,1,n,sa[i]-1);
pre[sa[i]][1]=findpre(1,1,n,pre[sa[i]][0]-1);
// printf("%d %d\n",pre[sa[i]][0],pre[sa[i]][1]);
insert(1,1,n,sa[i]);
}
for(int i=1;i<=n;i++)
{
// printf("%d %d %d\n",i,pre[i][0],pre[i][1]);
mp[suf[i][0]][0].push_back(i);
mp[suf[i][1]][1].push_back(i);
}
long long ans=0;
for(int i=1;i<=n;i++)
{
for(int j=0;j<mp[i][0].size();j++)
{
add(mp[i][0][j],1);
}
for(int j=0;j<mp[i][1].size();j++)
{
add(mp[i][1][j],-1);
}
ans+=query(pre[i][0])-query(pre[i][1]);
}
printf("%lld\n",ans);
return 0;
}