题意
给定一个序列
A
=
(
A
1
,
A
2
,
⋯
,
A
n
)
A=(A_1,A_2,\cdots,A_n)
A=(A1,A2,⋯,An)。
现在进行
m
m
m次操作,分为以下两种:
1 l r k d
:给定一个长度为 r − l + 1 r-l+1 r−l+1的等差序列,首项为 k k k,公差为 d d d,并将它对应加到 [ l , r ] [l,r] [l,r]范围中的每一个数上。2 x
:查询 A x A_x Ax的值。
思路
将序列
A
A
A进行差分,记差分数组为
B
B
B。
接下来,如果要加上一个等差序列,则要进行以下操作:
- B l = B l + k B_l=B_l+k Bl=Bl+k
- B i = B i + d ( i ∈ [ l , r ] ) B_i=B_i+d \quad (i \in [l, r]) Bi=Bi+d(i∈[l,r])
- B r + 1 = B r + 1 − ( k + d × ( r − l ) ) B_{r+1}=B_{r+1}-(k + d \times (r - l)) Br+1=Br+1−(k+d×(r−l))
如果不懂,看看下面的例子:
原序列:
A
=
(
0
,
0
,
0
,
0
,
0
,
0
)
A=(0,0,0,0,0,0)
A=(0,0,0,0,0,0)
差分序列:
B
=
(
0
,
0
,
0
,
0
,
0
,
0
)
B=(0,0,0,0,0,0)
B=(0,0,0,0,0,0)
等差序列:
C
=
(
1
,
3
,
5
,
7
,
9
)
C=(1,3,5,7,9)
C=(1,3,5,7,9)
现序列:
A
=
(
1
,
3
,
5
,
7
,
9
,
0
)
A=(1,3,5,7,9,0)
A=(1,3,5,7,9,0)
差分序列:
B
=
(
1
,
2
,
2
,
2
,
2
,
−
9
)
B=(1,2,2,2,2,-9)
B=(1,2,2,2,2,−9)
如果要查询
A
p
A_p
Ap,就输出
∑
i
=
1
p
B
i
\sum_{i=1}^{p} B_i
∑i=1pBi。
执行以上操作,只需要一个支持单点修改、区间修改。区间查询的线段树即可。
最后,注意
l
=
r
l=r
l=r或者
r
=
n
r=n
r=n的情况.
代码
#include <iostream>
#include <vector>
using namespace std;
#define int long long
struct segment {
#define ls (u << 1)
#define rs (u << 1 | 1)
struct Node {
int l, r, sum = 0, add = 0;
};
vector<Node> tr;
segment(vector<int> &a) {
int n = a.size();
tr.resize(n << 2);
build(1, 1, n, a);
}
void pushup(int u) {
tr[u].sum = tr[ls].sum + tr[rs].sum;;
}
void build(int u, int l, int r, vector<int> &a) {
tr[u].l = l; tr[u].r = r;
if (l == r) {
tr[u].sum = a[l - 1];
return;
}
int mid = l + r >> 1;
build(ls, l, mid, a);
build(rs, mid + 1, r, a);
pushup(u);
}
void pushdown(int u) {
if (tr[u].add) {
tr[ls].sum += tr[u].add * (tr[ls].r - tr[ls].l + 1);
tr[rs].sum += tr[u].add * (tr[rs].r - tr[rs].l + 1);
tr[ls].add += tr[u].add;
tr[rs].add += tr[u].add;
tr[u].add = 0;
}
}
void modify(int u, int l, int r, int v) {
if (l <= tr[u].l && tr[u].r <= r) {
tr[u].sum += v * (tr[u].r - tr[u].l + 1);
tr[u].add += v;
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) modify(ls, l, r, v);
if(r > mid) modify(rs, l, r, v);
pushup(u);
}
int query(int u, int l, int r) {
if (l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
int ans = 0;
if(l <= mid) ans += query(ls, l, r);
if(r > mid) ans += query(rs, l, r);
return ans;
}
#undef ls
#undef rs
};
signed main(){
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
int n, m;
cin >> n >> m;
vector<int> a(n);
for (auto &i: a) cin >> i;
vector<int> diff(n);
diff[0] = a[0];
for (int i = 1; i < n; i++) diff[i] = a[i] - a[i - 1];
segment seg(diff);
for (int i = 0; i < m; i++) {
int op;
cin >> op;
if (op == 1) {
int l, r, k, d;
cin >> l >> r >> k >> d;
seg.modify(1, l, l, k);
if(l + 1 <= r) seg.modify(1, l + 1, r, d);
if(r < n) seg.modify(1, r + 1, r + 1, -(k + d * (r - l)));
} else {
int p;
cin >> p;
cout << seg.query(1, 1, p) << endl;
}
}
return 0;
}