题目链接
https://atcoder.jp/contests/abc397/tasks/abc397_f
思路
令 L [ i ] L[i] L[i]表示区间 [ 1 , i ] [1,i] [1,i]中有多少个不同的数。
令 R [ i ] R[i] R[i]表示区间 [ i , n ] [i,n] [i,n]中有多少个不同的数。
令 n x t [ i ] nxt[i] nxt[i]表示下一个与 a [ i ] a[i] a[i]的值相同的数的下标。
假设我们现在已经枚举了第一个区间与第二个区间的分界点 i i i, i i i表示第二个区间的起始端点。则第一个区间 [ 1 , i ] [1,i] [1,i]对答案产生的贡献为 L [ i ] L[i] L[i]。
考虑第二个分界点 j j j( j j j表示第二个区间的终止端点)。
在区间 [ i , n ] [i,n] [i,n]中的数字,有两种情况:
- 1,这个数字只出现了一次,无论放到第二个区间还是第三个区间,都只会产生 1 1 1的贡献。
- 2,这个数字出现了多次,那么如果 j j j选在了这个数字出现两次的中间,就会让 [ i , j ] [i,j] [i,j]和 [ j + 1 , n ] [j+1,n] [j+1,n]都产生 1 1 1的贡献。
所以,对于区间 [ i , n ] [i,n] [i,n]中的数字,至少产生 1 1 1个贡献,还有可能多产生 1 1 1个贡献。
所以,最终的答案就是 m a x ( L [ i − 1 ] + R [ i ] + 多产生的贡献 ) max(L[i-1]+R[i]+多产生的贡献) max(L[i−1]+R[i]+多产生的贡献)。
现在开始考虑如何计算贡献:
对于区间 [ i , n ] [i,n] [i,n]上的下标为 k k k的数字,如果其出现了多次,则下一次出现的下标为 n x t [ k ] nxt[k] nxt[k]。如果第二个区间与第三个区间的分断点 j j j出现在 k k k和 n x t [ k ] nxt[k] nxt[k]之间,则 a [ k ] a[k] a[k]这个数字会额外产生 1 1 1个贡献。
因此,对于额外产生的贡献,我们可以用线段树实现区间加法,并维护区间最大值。
时间复杂度: O ( n l o g n ) O(nlogn) O(nlogn)。
代码
// #pragma GCC optimize("O2")
// #pragma GCC optimize("O3")
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define double long double
#define endl "\n"
typedef long long i64;
typedef unsigned long long u64;
typedef pair<int, int> pii;
const int N = 3e5 + 5, M = 1e6 + 5;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f3f3f3f3f;
const double eps = 1e-6;
std::mt19937 rnd(time(0));
int n;
int a[N], L[N], R[N], nxt[N];
struct SegmentTree
{
struct node
{
int l, r, maxx, tag;
};
vector<node>tree;
SegmentTree() {}
SegmentTree(int n) {tree.resize(n * 4 + 1);}
void pushup(int u)
{
auto &root = tree[u], &left = tree[u << 1], &right = tree[u << 1 | 1];
root.maxx = max(left.maxx , right.maxx);
}
void pushdown(int u)
{
auto &root = tree[u], &left = tree[u << 1], &right = tree[u << 1 | 1];
if (root.tag)
{
left.tag += root.tag;
right.tag += root.tag;
left.maxx += root.tag;
right.maxx += root.tag;
root.tag = 0;
}
}
void build(int u, int l, int r)
{
auto &root = tree[u];
root = {l, r};
if (l == r)
{
root.maxx = 0;
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int l, int r, int val)
{
auto &root = tree[u];
if (root.l >= l && root.r <= r)
{
root.maxx += val;
root.tag += val;
return;
}
pushdown(u);
int mid = root.l + root.r >> 1;
if (l <= mid) modify(u << 1, l, r, val);
if (r > mid) modify(u << 1 | 1, l, r, val);
pushup(u);
}
int query(int u, int l, int r)
{
auto &root = tree[u];
if (root.l >= l && root.r <= r)
{
return root.maxx;
}
pushdown(u);
int mid = root.l + root.r >> 1;
int res = 0;
if (l <= mid) res = query(u << 1, l, r);
if (r > mid) res = max(res, query(u << 1 | 1, l, r));
return res;
}
};
void solve(int test_case)
{
cin >> n;
for (int i = 1; i <= n; i++)
{
cin >> a[i];
}
set<int>st1, st2;
for (int i = 1, j = n; i <= n; i++, j--)
{
st1.insert(a[i]), st2.insert(a[j]);
L[i] = st1.size(), R[j] = st2.size();
}
vector<int>mp(n + 1, 0);
for (int i = n; i >= 1; i--)
{
nxt[i] = mp[a[i]];
mp[a[i]] = i;
}
SegmentTree smt(n);
smt.build(1, 1, n);
for (int i = n; i >= 2; i--)
{
if (nxt[i] > i)
{
smt.modify(1, i, nxt[i] - 1, 1);
}
}
int ans = 0;
for (int i = 2; i < n; i++)
{
int res = L[i - 1] + R[i];
res += smt.query(1, i, n - 1);
ans = max(ans, res);
if (nxt[i] > i)
{
smt.modify(1, i, nxt[i] - 1, -1);
}
}
cout << ans << endl;
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int test = 1;
// cin >> test;
for (int i = 1; i <= test ; i++)
{
solve(i);
}
return 0;
}