题目大意
给出一个长度为 n n n 的序列 a a a,问有多少个区间 [ l , r ] ( 1 ≤ l < r ≤ n ) [l,r] (1\le l < r\le n) [l,r](1≤l<r≤n) 满足 a l , a r a_l,a_r al,ar 在区间中都只出现了一次
解题思路
签到题
首先,我们记 p r e i , n x t i pre_i,nxt_i prei,nxti 分别表示 i i i 前后第一个与 a i a_i ai 相同的位置
接着,我们可以发现每个位置作为右端点对答案的贡献就是所有满足
n
x
t
j
>
i
nxt_j>i
nxtj>i 且
j
∈
(
p
r
e
i
,
i
)
j \in (pre_i,i)
j∈(prei,i) 的
j
j
j 的个数。注意这里小括号表示开区间
那么,我们可以用主席树维护
n
x
t
nxt
nxt,并用
O
(
log
n
)
\mathcal O(\log n)
O(logn) 的复杂度计算单点的贡献
总时间复杂度 O ( n log n ) \mathcal O(n \log n) O(nlogn)
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
const int Maxn=2e5+10;
const int Maxm=5e6+10;
int ls[Maxm],rs[Maxm];
int sum[Maxm],c[Maxn];
int nxt[Maxn],pre[Maxn];
int root[Maxn],a[Maxn];
int n,idcnt;
long long ans;
inline int read()
{
int s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0' && ch<='9')s=(s<<3)+(s<<1)+(ch^48),ch=getchar();
return s*w;
}
inline void push_up(int x)
{
sum[x]=sum[ls[x]]+sum[rs[x]];
}
void modify(int &x,int y,int l,int r,int pos)
{
x=++idcnt,sum[x]=sum[y]+1;
ls[x]=ls[y],rs[x]=rs[y];
if(l==r)return;
int mid=(l+r)>>1;
if(pos<=mid)modify(ls[x],ls[y],l,mid,pos);
else modify(rs[x],rs[y],mid+1,r,pos);
push_up(x);
}
int query(int x,int y,int l,int r,int u,int v)
{
if(u>v)return 0;
if(u<=l && r<=v)return sum[x]-sum[y];
int mid=(l+r)>>1,ret=0;
if(u<=mid)ret=query(ls[x],ls[y],l,mid,u,v);
if(mid<v)ret+=query(rs[x],rs[y],mid+1,r,u,v);
return ret;
}
int main()
{
// freopen("in.txt","r",stdin);
n=read();
for(int i=1;i<=n;++i)
a[i]=read();
for(int i=n;i;--i)
nxt[i]=c[a[i]],c[a[i]]=i;
memset(c,0,sizeof(c));
for(int i=1;i<=n;++i)
pre[i]=c[a[i]],c[a[i]]=i;
for(int i=1;i<=n;++i)
{
modify(root[i],root[i-1],0,n,nxt[i]);
int tmp=query(root[i-1],root[pre[i]],0,n,i+1,n);
tmp+=query(root[i-1],root[pre[i]],0,n,0,0);
ans+=1ll*tmp;
}
printf("%lld\n",ans);
return 0;
}