题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4578
题目比较麻烦,取模要细心一点。
区间加,乘,覆盖问题不是很大,比较麻烦的就是求的和的部分,因为p只取1,2,3,所以还是有可操作性的。
假设只有3个数a, b ,c。
如果对这三个数做加的操作,那么 p = 2,有 (a+1)^2 + (b+1)^2 + (c+1)^2 。把这个式子的平方拆开整理下就能得到一个式子:
a^2 + b^2 + c^2 + 2*1*(a+b+c) + 2*1^2。所以如果在一个区间做加的操作,那么要想更新这个区间的平方和,只需要知道原区间的平方和与原区间的和。
p = 3也是一样的道理,把括号拆开,也能得到类似的结论。
对于区间乘的操作应该比较容易想到,这里就不说了。
题目总体上是不难的,如果上面的推论能想得到的话,就是一个简单的线段树题,但是代码写起来就需要足够细心和耐性,是一个非常麻烦的题目,wa了无数发。
对于线段树的区间更新,我有两种比较常用的方法,第一种是通过调取左区间和右区间的值,相加得到自己区间的值。第二种是通过记录子节点区间的更新值,把这个值直接通过 cur / 2 返回给父节点去更新自己的值,比如我给编号为 2的节点区间+x,那么修改完2这个节点的同时把这个x返回给父节点 1,让父节点也+x。下面的代码就是用的这种方法更新的数据,可能和大部分人写的不太一样。
还有一点,一开始写这题的时候,没怎么仔细的想,觉得乘,加的标记应该是不能共存的,也就是说一个节点有乘的标记又有加的标记,那么我更新这个节点的值很难知道是应该先乘还是先加。事实上是可以共存的,比如当前节点已经有加的标记了,当更新这个节点乘的标记时,不仅要更新乘的标记,同时还需要更新加的标记,比如说 a+3,如果对这个数做乘法,那么应该是 a*b + 3*b。这样一个节点就能同时存在乘标记和加标记,更新值时要先乘后加。
一定要注意取模爆int,最好变量全都设为 long long。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
using namespace std;
const int Maxn = 1e5+10;
const int mod = 10007;
typedef long long ll;
ll sum1[Maxn<<2], sum2[Maxn<<2], sum3[Maxn<<2], add[Maxn<<2],
mul[Maxn<<2], chan[Maxn<<2];
ll L, R, c, op;
void pushdown(ll cur, ll l, ll r) {
if(chan[cur]) {
sum1[cur] = chan[cur]*(r-l+1) % mod;
sum2[cur] = ((chan[cur]*chan[cur])%mod * (r-l+1)) % mod;
sum3[cur] = (((chan[cur]*chan[cur])%mod *chan[cur]) % mod * (r-l+1)) % mod;
if(l != r) {
chan[cur<<1] = chan[cur<<1|1] = chan[cur]; // 如果把覆盖的标记往下传,那么下传的标记如果有乘
mul[cur<<1] = add[cur<<1] = 0; // 或者加,一并清除掉。
mul[cur<<1|1] = add[cur<<1|1] = 0;
}
chan[cur] = 0;
}
if(mul[cur]) {
sum1[cur] = (sum1[cur]*mul[cur]) % mod;
sum2[cur] = ((mul[cur]*mul[cur])%mod * sum2[cur]) % mod;
sum3[cur] = (((mul[cur]*mul[cur])%mod * mul[cur]) %mod * sum3[cur]) % mod;
if(l != r) { // 对于乘的下传,不仅要更新乘的标记,加的也要更新
if(mul[cur<<1]) mul[cur<<1] = (mul[cur<<1] * mul[cur]) % mod;
else mul[cur<<1] = mul[cur];
if(mul[cur<<1|1]) mul[cur<<1|1] = (mul[cur<<1|1] * mul[cur]) % mod;
else mul[cur<<1|1] = mul[cur];
add[cur<<1] = (add[cur<<1] * mul[cur]) % mod;
add[cur<<1|1] = (add[cur<<1|1] * mul[cur]) % mod;
}
mul[cur] = 0;
}
if(add[cur]) {
ll len = (r-l+1);
sum3[cur] = (((3*add[cur]*add[cur])%mod * sum1[cur]) %mod + (3*add[cur]*sum2[cur])% mod
+ ((len*add[cur]*add[cur] % mod) * add[cur])%mod + sum3[cur]) % mod;
sum2[cur] = (((2*add[cur]*sum1[cur])%mod + (len*add[cur]*add[cur])%mod) % mod + sum2[cur]) % mod;
sum1[cur] = (sum1[cur]+(len*add[cur]))%mod;
if(l != r) {
add[cur<<1] = (add[cur<<1]+add[cur]) % mod;
add[cur<<1|1] = (add[cur<<1|1]+add[cur]) % mod;
}
add[cur] = 0;
}
}
void updata(ll cur, ll l, ll r) {
pushdown(cur, l, r); // 标记下传,这时更新的是旧的标记
ll tmp1 = sum1[cur], tmp2 = sum2[cur], tmp3 = sum3[cur]; // 用来记录原区间的数据, 当更新数据时,
if(L <= l && r <= R) { // 与更新的数据做差,把这个差值返回给父节点更新区间
if(op == 1) {
add[cur] = (add[cur]+c) % mod;
} else if(op == 2) {
add[cur] = add[cur]*c % mod;
if(mul[cur]) mul[cur] = (ll)mul[cur]*c % mod;
else mul[cur] = c;
} else {
chan[cur] = c;
add[cur] = mul[cur] = 0;
}
pushdown(cur, l, r); // 更新新的标记, 也就是上面代码处理的标记
sum1[cur/2] = (sum1[cur/2]+(sum1[cur]-tmp1)+mod) % mod; // 注意这里的+mod,如果没有取模操作,
sum2[cur/2] = (sum2[cur/2]+(sum2[cur]-tmp2)+mod) % mod; // 原则上是不可能有负值,如果是取模,
sum3[cur/2] = (sum3[cur/2]+(sum3[cur]-tmp3)+mod) % mod; // 那就有可能存在负值
return;
}
ll mid = (l+r)/2;
if(L <= mid) updata(cur<<1, l, mid);
if(mid+1 <= R) updata(cur<<1|1, mid+1, r);
sum1[cur/2] = (sum1[cur/2]+(sum1[cur]-tmp1)+mod) % mod;
sum2[cur/2] = (sum2[cur/2]+(sum2[cur]-tmp2)+mod) % mod;
sum3[cur/2] = (sum3[cur/2]+(sum3[cur]-tmp3)+mod) % mod;
}
ll solve(ll cur, ll l, ll r) {
pushdown(cur, l, r); // 标记下传
if(L <= l && r <= R) {
if(c == 1) return sum1[cur];
else if(c== 2) return sum2[cur];
else return sum3[cur];
}
ll mid = (r+l)/2, m1 = 0, m2 = 0;
if(L <= mid) m1 = solve(cur<<1, l, mid);
if(mid+1 <= R) m2 = solve(cur<<1|1, mid+1, r);
return (m1+m2) % mod;
}
int main(void)
{
int N, M;
while (scanf("%d%d", &N, &M) != EOF) {
if(N == 0 && M == 0) break;
memset(sum1, 0, sizeof(sum1));
memset(sum2, 0, sizeof(sum2));
memset(sum3, 0, sizeof(sum3));
memset(add, 0, sizeof(add));
memset(mul, 0, sizeof(mul));
memset(chan, 0, sizeof(chan));
while(M--) {
scanf("%I64d%I64d%I64d%I64d", &op, &L, &R, &c);
if(op == 4) printf("%I64d\n",solve(1, 1, N));
else updata(1, 1, N);
}
}
return 0;
}