题目链接
Educational Codeforces Round 166 D. Invertible Bracket Sequences
思路
我们将 ‘ ( ’ ‘(’ ‘(’当作 1 1 1, ‘ ) ’ ‘)’ ‘)’当作 − 1 -1 −1,之后预处理出字符串的前缀和数组 s u m sum sum。
想要反转一段区间之后的字符串仍然为正则括号序列,必须满足: 被反转的区间内,两种括号的数量相等。否则,反转之后的字符串绝无可能是正则括号序列。
我们假设被反转区间的左边的区间的区间和为 o p 1 op1 op1,被反转区间的区间和为 o p 2 op2 op2。如果 o p 2 ≤ o p 1 op2 \le op1 op2≤op1,则该区间反转之后的字符串仍然为正则括号序列。因为区间被反转之后, o p 2 op2 op2的值就变为了 − o p 2 -op2 −op2,而想要字符串为正则括号序列,则所有的 s u m [ i ] sum[i] sum[i]都必须有 s u m [ i ] ≥ 0 sum[i] \ge 0 sum[i]≥0。
因此我们可以使用线段树来维护 s u m sum sum数组的区间最大值(因为是静态的,用ST表也可以)。用map< i n t int int,vector< i n t int int>>来映射每一个 s u m [ i ] sum[i] sum[i]的值所对应的数组下标。
枚举每一个被反转区间的左端点 i i i,因为 [ i , i : n ] [i,i:n] [i,i:n]的区间最大值具有单调不减的性质,因此区间右端点最大的值 j j j可以使用二分查找求得。
最后,在 m a p [ s u m [ i − 1 ] ] map[sum[i-1]] map[sum[i−1]]上使用二分来求得满足条件的右端点的个数即可。
代码
#include <bits/stdc++.h>
using namespace std;
#define int long long
typedef pair<int, int> pii;
const int N = 2e5 + 5, M = 1e6 + 5;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f3f3f3f3f;
int n;
int a[N], sum[N];
string s;
struct segmenttree
{
struct node
{
int l, r, maxx, tag;
};
vector<node>tree;
segmenttree(): tree(1) {}
segmenttree(int n): tree(n * 4 + 1) {}
void pushup(int u)
{
auto &root = tree[u], &left = tree[u << 1], &right = tree[u << 1 | 1];
root.maxx = max(left.maxx, right.maxx);
}
void pushdown(int u)
{
auto &root = tree[u], &left = tree[u << 1], &right = tree[u << 1 | 1];
if (root.tag != 0)
{
left.tag += root.tag;
right.tag += root.tag;
left.maxx = left.maxx + root.tag;
right.maxx = right.maxx + root.tag;
root.tag = 0;
}
}
void build(int u, int l, int r)
{
auto &root = tree[u];
root = {l, r};
if (l == r)
{
root.maxx = sum[r];
}
else
{
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int l, int r, int val)
{
auto &root = tree[u];
if (root.l >= l && root.r <= r)
{
root.maxx += val;
root.tag += val;
return;
}
pushdown(u);
int mid = root.l + root.r >> 1;
if (l <= mid) modify(u << 1, l, r, val);
if (r > mid) modify(u << 1 | 1, l, r, val);
pushup(u);
}
int query(int u, int l, int r)
{
auto &root = tree[u];
if (root.l >= l && root.r <= r)
{
return root.maxx;
}
pushdown(u);
int mid = root.l + root.r >> 1;
int res = -inf;
if (l <= mid) res = query(u << 1, l, r);
if (r > mid) res = max(res, query(u << 1 | 1, l, r));
return res;
}
};
void solve()
{
cin >> s;
n = s.size();
for (int i = 0; i < n; i++)
{
a[i + 1] = (s[i] == '(' ? 1 : -1);
}
map<int, vector<int>>mp;
mp[0].push_back(0);//防止后面二分时越界
for (int i = 1; i <= n; i++)
{
sum[i] = sum[i - 1] + a[i];
mp[sum[i]].push_back(i);
}
segmenttree smt(n);
smt.build(1, 1, n);
int ans = 0;
for (int i = 1; i <= n; i++)
{
int res = sum[i - 1];
int low = i, high = n;
while (low < high)
{
int mid = low + high + 1 >> 1;
auto check = [&](auto check, int x)->bool{
int maxx = smt.query(1, i, x);
maxx -= sum[i - 1];
return maxx <= res;
};
if (check(check, mid))
{
low = mid;
}
else high = mid - 1;
}
if (!mp.count(sum[i - 1])) continue;
int R = upper_bound(mp[sum[i - 1]].begin(), mp[sum[i - 1]].end(), low) - mp[sum[i - 1]].begin();
R--;
int L = lower_bound(mp[sum[i - 1]].begin(), mp[sum[i - 1]].end(), i) - mp[sum[i - 1]].begin();
if (L <= R)
{
ans += (R - L + 1);
}
}
cout << ans << endl;
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int test = 1;
cin >> test;
for (int i = 1; i <= test; i++)
{
solve();
}
return 0;
}