题目大意
给定一个 1 ∼ n 1\sim n 1∼n 的排列 a 1 , a 2 , . . . , a n a_1, a_2, ...,a_n a1,a2,...,an ,求极大上升子序列(即不存在真包含它的上升子序列)的个数。( n ≤ 1 0 5 n\le 10^5 n≤105)
思路
首先有一个
O
(
n
2
)
O(n^2)
O(n2) 的做法:
记
f
j
f_j
fj 为
a
1
,
a
2
,
.
.
.
,
a
i
a_1, a_2, ..., a_i
a1,a2,...,ai 中以
a
i
a_i
ai 为结尾的极大上升子序列个数。
然后
f
j
f_j
fj 对
f
i
f_i
fi 有贡献(
f
i
+
=
f
j
f_i+=f_j
fi+=fj)当且仅当
a
j
<
a
i
a_j<a_i
aj<ai 且
∄
k
(
j
<
k
<
i
)
,
a
j
<
a
k
<
a
i
\not\exists k(j<k<i), a_j<a_k<a_i
∃k(j<k<i),aj<ak<ai。
最后设
a
n
+
1
=
n
+
1
a_{n+1}=n+1
an+1=n+1 ,则
f
n
+
1
f_{n+1}
fn+1 即为答案。
然后发现可以利用CDQ分治进行优化。
考虑
[
l
,
m
]
[l, m]
[l,m] 对
[
m
+
1
,
r
]
[m+1, r]
[m+1,r] 中各元素的贡献,可以从小到大考虑
[
l
,
r
]
[l, r]
[l,r] 中的各个数,并维护两个单调栈,
[
l
,
m
]
[l, m]
[l,m] 的元素下标递减,
[
m
+
1
,
r
]
[m+1, r]
[m+1,r] 的元素下标递增。
对
i
∈
[
m
+
1
,
r
]
i\in [m+1, r]
i∈[m+1,r] ,找到
[
m
+
1
,
r
]
[m+1, r]
[m+1,r] 栈中
a
i
a_i
ai 左侧最靠右的元素(即退栈后的栈顶)
a
j
a_j
aj。
然后找到
[
l
,
m
]
[l, m]
[l,m] 中第一个比
a
j
a_j
aj 大的和最后一个比
a
i
a_i
ai 小的(即栈顶)元素。
[
l
,
m
]
[l, m]
[l,m] 栈中它们之间的元素和即
[
l
,
m
]
[l, m]
[l,m] 对
f
i
f_i
fi 的贡献。
代码
#include <bits/stdc++.h>
#define rep(i, l, r) for (int i = l; i <= r; ++i)
#define per(i, r, l) for (int i = r; i >= l; --i)
using namespace std;
const int inf = 0x3fffffff;
const int N = 100005;
const int mod = 998244353;
int T;
int n;
class node {
public:
int id, val;
bool operator<(const node &rhs) const { return val < rhs.val; }
} a[N];
int dp[N];
void init() {
int mx = inf;
rep(i, 1, n) {
if (a[i].val < mx) {
mx = a[i].val;
dp[i] = 1;
} else
dp[i] = 0;
}
}
void solve(int l, int r) {
if (l == r) return;
int mid = (l + r) / 2;
solve(l, mid);
sort(a + l, a + r + 1);
static node sta1[N], sta2[N];
static int top1, top2;
static int s[N];
top1 = top2 = 0;
rep(i, l, r) {
// printf("%d ", a[i].val);
if (a[i].id <= mid) {
while (top1 && sta1[top1].id < a[i].id) {
--top1;
}
++top1;
sta1[top1] = a[i];
s[top1] = (s[top1 - 1] + dp[a[i].id]) % mod;
} else {
while (top2 && sta2[top2].id > a[i].id) {
--top2;
}
int j = lower_bound(sta1 + 1, sta1 + top1 + 1, sta2[top2]) - sta1;
dp[a[i].id] = (1ll * dp[a[i].id] + s[top1] - s[j - 1] + mod) % mod;
++top2;
sta2[top2] = a[i];
}
}
// printf("\n");
sort(a + l, a + r + 1, [](node x, node y) { return x.id < y.id; });
solve(mid + 1, r);
}
int main() {
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
rep(i, 1, n) {
a[i].id = i;
scanf("%d", &a[i].val);
}
++n;
a[n].id = a[n].val = n;
init();
solve(1, n);
printf("%d\n", dp[n]);
}
return 0;
}