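/*
思路:
关键在于乘法标记与加法标记的先后关系。由分配律 (x + y) * k = x * k + y * k 可知:
当一个区间已有加法标记、又来一个"乘 k"操作时,只需把已有的加法标记也乘以 k,
就能统一按"先乘后加"的顺序维护两个标记,从而不必再关心 + 与 * 到来的先后顺序。
pushdown 时,子节点新值 = 原有数据 * 乘法标记 + 加法标记 * 区间长度;
子节点乘法标记 = 子乘法标记 * 父乘法标记,子加法标记 = 子加法标记 * 父乘法标记 + 父加法标记。
注意点:
- 这个题的数据范围需要用 long long
- 读入的 k 也是 long long,传入函数时也要用 long long
- 个人 WA(Wrong Answer)点:pushdown 时暂存 mul 和 add 用了 int
*/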
#include <iostream>
using namespace std;
// Child-index helpers for the 1-based implicit segment tree: the children of
// node x are 2x (left) and 2x+1 (right). Both the argument and the whole
// expansion are parenthesised so the macros stay correct inside larger
// expressions such as Rn(p) + 1 or Ln(a | b).
#define Ln(x) ((x) << 1)
#define Rn(x) (((x) << 1) | 1)
// Plain newline: unlike std::endl it does not flush the stream on every line.
#define endl "\n"
#define ll long long
// Capacity of the 1-based input array a[]; valid indices are 1..maxn-1.
const int maxn = 100000 + 5;
// A segment-tree node covering the inclusive index range [l, r].
// `mul`/`add` are lazy tags already applied to this node's `data` but not yet
// pushed to its children; they follow the "multiply first, then add" order.
struct Tree{
    int l;
    int r;
    ll data;      // sum of the covered range
    ll mul = 1;   // pending multiplicative tag (identity value 1)
    ll add = 0;   // pending additive tag (identity value 0)
    Tree(int l = 0, int r = 0, ll data = 0) : l(l), r(r), data(data) {}
};
Tree tree[maxn * 4 + 10];   // node storage for the implicit binary tree

ll a[maxn];   // input values, indexed from 1
ll mod;       // modulus for all stored sums and tags
void pushup(int p){
tree[p].data = (tree[Ln(p)].data + tree[Rn(p)].data) % mod;
}
void build(int l, int r, int p){
tree[p] = Tree(l, r, 0);
if(l == r){
tree[p] = Tree(l, r, a[l]);
return;
}
int mid = (l + r) >> 1;
build(l, mid, Ln(p));
build(mid+1, r, Rn(p));
pushup(p);
}
void pushdown(int p){
ll mul = tree[p].mul;
ll add = tree[p].add;
tree[Ln(p)].mul = (mul * tree[Ln(p)].mul) % mod;
tree[Ln(p)].add = (tree[Ln(p)].add * mul + add) % mod;
tree[Rn(p)].mul = (mul * tree[Rn(p)].mul) % mod;
tree[Rn(p)].add = (tree[Rn(p)].add * mul + add) % mod;
//乘法优先
tree[Ln(p)].data = (tree[Ln(p)].data * mul + add * (tree[Ln(p)].r - tree[Ln(p)].l + 1)) % mod;
tree[Rn(p)].data = (tree[Rn(p)].data * mul + add * (tree[Rn(p)].r - tree[Rn(p)].l + 1)) % mod;
tree[p].mul = 1;
tree[p].add = 0;
}
// Applies one range operation to every index in [l, r] within the subtree
// rooted at node p, modulo `mod`:
//   op == 1      : multiply each element by d
//   any other op : add d to each element (main() calls this for every op except 3)
// NOTE(review): assumes (mod - 1) * (mod - 1) fits in long long.
void update(int l, int r, int p, int op, ll d){
    // Reduce d into [0, mod) first. A raw d may be negative or as large as a
    // long long; that would overflow `d * (nr - nl + 1)` / `data * d` below and
    // could leave negative residues in the tree.
    d = (d % mod + mod) % mod;
    int nl = tree[p].l, nr = tree[p].r;
    if(nl >= l && nr <= r){
        // Fully covered: update this node's sum now and fold the operation
        // into its lazy tags; the children are updated later by pushdown().
        if(op == 1){
            tree[p].mul = tree[p].mul * d % mod;
            tree[p].data = tree[p].data * d % mod;
            // Multiplication distributes over the pending add:
            // (x * mul + add) * d = x * (mul * d) + add * d.
            tree[p].add = tree[p].add * d % mod;
        }
        else{
            tree[p].add = (tree[p].add + d) % mod;
            tree[p].data = (d * (nr - nl + 1) + tree[p].data) % mod;
        }
        return;
    }
    // Partially covered: bring the children up to date before recursing.
    if(tree[p].mul != 1 || tree[p].add) pushdown(p);
    int mid = (nl + nr) >> 1;
    if (l <= mid) update(l, r, Ln(p), op, d);
    if (r > mid) update(l, r, Rn(p), op, d);
    pushup(p);
}
// Returns the sum of the elements in [l, r] that lie inside node p's subtree,
// modulo `mod`.
ll query(int l, int r, int p){
    const int lo = tree[p].l;
    const int hi = tree[p].r;
    // Node lies completely inside the query range: its cached sum is current.
    if (l <= lo && hi <= r) {
        return tree[p].data;
    }
    // The children may be stale until this node's tags are pushed down.
    const bool hasPendingTags = !(tree[p].mul == 1 && tree[p].add == 0);
    if (hasPendingTags) {
        pushdown(p);
    }
    const int mid = (lo + hi) >> 1;
    ll total = 0;
    if (l <= mid) {
        total = query(l, r, Ln(p)) % mod;
    }
    if (mid < r) {
        total = (total + query(l, r, Rn(p))) % mod;
    }
    return total;
}
// Reads n, m and the modulus, then a[1..n]. Builds the tree and processes m
// operations:
//   1 l r k             : multiply a[l..r] by k
//   any other op but 3  : "op l r k" adds k to a[l..r]
//   3 l r               : print the sum of a[l..r] modulo `mod`
int main()
{
    // All I/O goes through iostreams, so it is safe to decouple them from C
    // stdio. Untying cin from cout avoids a flush before every read.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int n, m;
    cin >> n >> m >> mod;
    for (int i=1; i<=n; ++i)
        cin >> a[i];
    build(1, n, 1);
    while(m --){
        int op;
        cin >> op;
        if(op != 3){
            int l, r;
            ll d;   // k may exceed the int range
            cin >> l >> r >> d;
            update(l, r, 1, op, d);
        }
        else{
            int l, r;
            cin >> l >> r;
            cout << query(l, r, 1) << endl;
        }
    }
    return 0;
}