Description
给定 $01$ 序列 $a=(a_1,a_2,\cdots,a_n)$,并定义 $f(l,r)=\left[\left(\sum\limits_{i=l}^r a_i\right)=r-l+1\right]$(即 $[l,r]$ 内全为 $1$ 时为 $1$,否则为 $0$).
执行 $m$ 个操作,分五种:
- $\operatorname{reset}(l,r)$:对每个 $i\in[l,r]$ 执行 $a_i\gets 0$.
- $\operatorname{set}(l,r)$:对每个 $i\in[l,r]$ 执行 $a_i\gets 1$.
- $\operatorname{negate}(l,r)$:对每个 $i\in[l,r]$ 执行 $a_i\gets 1-a_i$.
- $\operatorname{sum}(l,r)$:求 $\sum\limits_{i=l}^r a_i$.
- $\operatorname{gss}(l,r)$:求 $\max\limits_{[s,t]\subseteq[l,r]}(t-s+1)\times f(s,t)$,即 $[l,r]$ 内最长连续 $1$ 的长度.
Limitations
$1\le n,m\le 10^5$
$1\le l\le r\le n$
$1\text{s},\ 128\text{MB}$
Solution
线段树做法.
节点上维护 $S=(\textit{sum},\textit{lsum},\textit{rsum},\textit{tsum},\textit{len})$,用类似最大子段和的方法合并.
由于要取反,$0$ 和 $1$ 都分别需要一个 $S$(下记为 $S_0$ 和 $S_1$).
然后需要懒标记 $T$,显然 $T=(\textit{cov},\textit{rev})$(分别为赋值标记和取反标记).
- 考虑赋值:直接将 $S_{\textit{cov}}$ 的各项全部填为 $\textit{len}$,$S_{1-\textit{cov}}$ 的各项全部清零($\textit{len}$ 本身保持不变).
- 考虑翻转:直接交换 $S_0$ 和 $S_1$.
然后需要考虑标记的合并 $T+T$:
- 赋值可以直接打标记.
- 对于取反:若区间已被赋值,取反就相当于赋值 $1-\textit{cov}$,于是直接将 $\textit{cov}$ 取反即可(此时 $\textit{rev}$ 没用,要保持为 $0$);否则,直接翻转 $\textit{rev}$.
有不少坑点:
- 输入的下标从 $0$ 开始.
- 没有赋值标记的节点,$\textit{cov}$ 要置为 $-1$.
- 合并 $S$ 时,只有左半区间全满,$\textit{lsum}$ 才能跨越中点;$\textit{rsum}$ 同理(要求右半区间全满).
- 下传标记时先处理 $\textit{cov}$ 再处理 $\textit{rev}$.
- 打 $\textit{rev}$ 标记时要写 `rev ^= 1`,不要写成 `rev = 1`.
Code
$3.9\text{KB},\ 0.45\text{s},\ 11.63\text{MB}$ (in total, C++20 with O2)
建议封装 $S$(见代码中的 `Data`),会好写不少.
#include <bits/stdc++.h>
using namespace std;
// Short arithmetic type aliases (competitive-programming shorthand).
// Integers: 64-bit-or-wider builtins plus the GCC/Clang 128-bit extension.
using i64 = long long int;
using ui64 = unsigned long long int;
using i128 = __int128;
using ui128 = unsigned __int128;
// Floating point. f16 is `long double`, whose size and precision are
// platform-dependent.
using f4 = float;
using f8 = double;
using f16 = long double;
/// Raises `a` to `b` if `b` is strictly larger (compared with `operator<`).
/// @return true iff `a` was overwritten.
template <class T>
bool chmax(T& a, const T& b) {
    const bool improved = a < b;
    if (improved) a = b;
    return improved;
}
/// Lowers `a` to `b` if `b` is strictly smaller (compared with `operator>`).
/// @return true iff `a` was overwritten.
template <class T>
bool chmin(T& a, const T& b) {
    if (!(a > b)) return false;
    a = b;
    return true;
}
namespace seg_tree {
// Summary of a segment with respect to one bit value b (b = 0 or 1).
// Every count below refers to cells equal to b.
struct Data {
    int sum = 0;   // number of cells equal to b
    int lsum = 0;  // longest run of b starting at the left end
    int rsum = 0;  // longest run of b ending at the right end
    int tsum = 0;  // longest run of b anywhere in the segment
    int len = 0;   // number of cells in the segment
    // All members start at 0 via their default member initialisers.
    Data() = default;
    // Single cell: x is 1 if the cell equals b, 0 otherwise.
    inline Data(int x) : sum(x), lsum(x), rsum(x), tsum(x), len(1) {}
    inline Data(int _sum, int _lsum, int _rsum, int _tsum, int _len)
        : sum(_sum), lsum(_lsum), rsum(_rsum), tsum(_tsum), len(_len) {}
};
// Merges the summaries of two adjacent segments (lhs directly left of rhs),
// in the manner of the classic maximum-subarray segment tree.
inline Data operator+(const Data& lhs, const Data& rhs) {
    // A prefix run may only extend across the midpoint when the left part is
    // entirely b; symmetrically for the suffix run and the right part.
    const bool lhs_full = lhs.sum == lhs.len;
    const bool rhs_full = rhs.sum == rhs.len;
    const int _sum = lhs.sum + rhs.sum;
    const int _lsum = lhs.lsum + (lhs_full ? rhs.lsum : 0);
    const int _rsum = rhs.rsum + (rhs_full ? lhs.rsum : 0);
    const int _tsum = std::max({lhs.tsum, rhs.tsum, lhs.rsum + rhs.lsum});
    const int _len = lhs.len + rhs.len;
    return Data(_sum, _lsum, _rsum, _tsum, _len);
}
// One tree node: its range, summaries for both bit values and its lazy tags.
struct Node {
    int l = 0, r = 0;           // covered index range [l, r]
    std::array<Data, 2> dat{};  // dat[b]: summary with respect to bit value b
    int cov = -1;               // pending assignment (0 or 1); -1 means none
    int rev = 0;                // pending negation; kept 0 whenever cov != -1
    inline Data& operator[](int i) { return dat[i]; }
    inline const Data& operator[](int i) const { return dat[i]; }
};
// Children of an internal node whose split point is `mid` live at 2*mid+1 and
// 2*mid+2. Every split point in [0, n-2] belongs to exactly one internal node,
// so indices are unique and stay below 2n-1 (root is 0); hence 2n slots suffice.
inline int ls(int u) { return 2 * u + 1; }
inline int rs(int u) { return 2 * u + 2; }
// Lazy segment tree over a 0/1 array supporting range assignment, range
// negation, range count of ones and range longest run of ones.
// Ranges passed to update/query must satisfy 0 <= l <= r < n on a non-empty tree.
struct SegTree {
    std::vector<Node> tr;
    inline SegTree() {}
    // Builds the tree over `a` (entries are expected to be 0 or 1).
    // An empty `a` yields an empty tree instead of writing past the vector.
    inline SegTree(const std::vector<int>& a) {
        const int n = static_cast<int>(a.size());
        if (n == 0) return;
        tr.resize(static_cast<std::size_t>(n) * 2);
        build(0, 0, n - 1, a);
    }
    // Recomputes node u from its two children (mid is u's split point).
    inline void pushup(int u, int mid) {
        for (int i = 0; i < 2; i++) tr[u][i] = tr[ls(mid)][i] + tr[rs(mid)][i];
    }
    // Applies a tag to node u: first an optional assignment (cov = 0/1, -1 for
    // none), then an optional negation (rev != 0). Negating an assigned node just
    // flips its assigned value, so a node never holds cov != -1 with rev != 0.
    inline void apply(int u, int cov, int rev) {
        Node& node = tr[u];
        if (cov != -1) {
            const int len = node[cov].len;
            node[cov] = Data(len, len, len, len, len);
            node[cov ^ 1] = Data(0, 0, 0, 0, len);
            node.cov = cov;
            node.rev = 0;
        }
        if (rev) {
            std::swap(node.dat[0], node.dat[1]);
            if (node.cov != -1) node.cov ^= 1;
            else node.rev ^= 1;
        }
    }
    // Pushes u's pending tags to its children and clears them.
    inline void pushdown(int u, int mid) {
        apply(ls(mid), tr[u].cov, tr[u].rev);
        apply(rs(mid), tr[u].cov, tr[u].rev);
        tr[u].cov = -1, tr[u].rev = 0;
    }
    void build(int u, int l, int r, const std::vector<int>& a) {
        // Tags are reset explicitly instead of relying on how `tr` was allocated.
        tr[u].l = l, tr[u].r = r, tr[u].cov = -1, tr[u].rev = 0;
        if (l == r) {
            for (int i = 0; i < 2; i++) tr[u][i] = Data(a[l] == i);
            return;
        }
        const int mid = (l + r) >> 1;
        build(ls(mid), l, mid, a);
        build(rs(mid), mid + 1, r, a);
        pushup(u, mid);
    }
    // Applies (cov, rev) to every index in [l, r] below node u.
    void update(int u, int l, int r, int cov, int rev) {
        if (l <= tr[u].l && tr[u].r <= r) return apply(u, cov, rev);
        const int mid = (tr[u].l + tr[u].r) >> 1;
        pushdown(u, mid);
        if (l <= mid) update(ls(mid), l, r, cov, rev);
        if (r > mid) update(rs(mid), l, r, cov, rev);
        pushup(u, mid);
    }
    // Returns the bit-1 summary of [l, r] ∩ (node u's range).
    Data query(int u, int l, int r) {
        if (l <= tr[u].l && tr[u].r <= r) return tr[u][1];
        const int mid = (tr[u].l + tr[u].r) >> 1;
        pushdown(u, mid);
        if (r <= mid) return query(ls(mid), l, r);
        else if (l > mid) return query(rs(mid), l, r);
        else return query(ls(mid), l, r) + query(rs(mid), l, r);
    }
    inline void range_cover(int l, int r, int v) { update(0, l, r, v, 0); }
    inline void range_negate(int l, int r) { update(0, l, r, -1, 1); }
    inline int range_sum(int l, int r) { return query(0, l, r).sum; }
    inline int range_gss(int l, int r) { return query(0, l, r).tsum; }
};
}
using seg_tree::SegTree;
// Reads n, m, the initial 0/1 array and m operations "op l r" (0-indexed,
// inclusive) from stdin, and prints the answer to every query operation:
//   op 0: assign 0 on [l, r]        op 1: assign 1 on [l, r]
//   op 2: negate [l, r]             op 3: print the number of ones in [l, r]
//   op 4: print the longest run of consecutive ones in [l, r]
// All I/O goes through C stdio, so no iostream configuration is needed.
signed main() {
    int n, m;
    if (scanf("%d %d", &n, &m) != 2 || n < 0) return 1;
    std::vector<int> a(n);
    for (int i = 0; i < n; i++) {
        if (scanf("%d", &a[i]) != 1) return 1;
    }
    SegTree sgt(a);
    for (int i = 0; i < m; i++) {
        int op, l, r;
        // Stop on truncated input instead of acting on stale or indeterminate values.
        if (scanf("%d %d %d", &op, &l, &r) != 3) break;
        // The tree's update/query assume 0 <= l <= r < n; other ranges would make
        // the recursion descend past the leaves, so they are skipped.
        if (l < 0 || r >= n || l > r) continue;
        switch (op) {
            case 0: sgt.range_cover(l, r, 0); break;
            case 1: sgt.range_cover(l, r, 1); break;
            case 2: sgt.range_negate(l, r); break;
            case 3: printf("%d\n", sgt.range_sum(l, r)); break;
            case 4: printf("%d\n", sgt.range_gss(l, r)); break;
            default: break;  // unknown op codes are ignored, as before
        }
    }
    return 0;
}
154

被折叠的 条评论
为什么被折叠?



