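// Problem: four operations on an array a[1..n]:
//   1 x y c : add c to every a[i] with x <= i <= y
//   2 x y c : multiply every a[i] in [x, y] by c
//   3 x y c : set every a[i] in [x, y] to c
//   4 x y p : print a[x]^p + a[x+1]^p + ... + a[y]^p (p = 1, 2 or 3), modulo 10007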
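// Approach: each segment-tree node keeps sum1, sum2 and sum3, the sums of a[i], a[i]^2 and a[i]^3
// over its range; all three can be updated in closed form, e.g.
//   sum (a[i] + c)^2 = sum2 + 2c*sum1 + c^2*len.
// Applying additions and multiplications naively in arbitrary order gives wrong results, so every
// pending update is kept in the form "multiply by mul, then add add". When a multiplication is
// applied, the pending add must be multiplied as well. An assignment (operation 3) must discard
// any pending add/multiply left behind by earlier operations.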
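// Pitfall: the solution kept getting Wrong Answer until every variable was switched to long long,
// most likely because intermediate products such as len * c^3 overflow a 32-bit int.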
// Problem link: http://acm.hdu.edu.cn/showproblem.php?pid=4578
#include <cstdio>
#include <cmath>
#include <iostream>
using namespace std;
typedef long long LL;

// Segment-tree capacity: positions 1..MAXN-1 (the tree array has 4 * MAXN nodes).
const int MAXN = 1000000 + 5;
// Every stored sum and every printed answer is reduced modulo this value.
const LL MOD = 10007;

// One segment-tree node covering the inclusive position range [x, y].
// Pending lazy tags are composed as: value -> (set ? set : value) * mul + add,
// i.e. a pending assignment applies first, then the multiply, then the add.
struct Node {
    LL x, y;             // inclusive range covered by this node
    LL add, mul, set;    // pending tags for the children; set == 0 means "no pending assignment"
    LL sum1, sum2, sum3; // sums of a[i], a[i]^2, a[i]^3 over [x, y], mod MOD
} t[MAXN << 2];

// Number of positions covered by node rt (replaces the former dist() macro).
inline LL dist(LL rt) { return t[rt].y - t[rt].x + 1; }

LL n, m;                 // array length and operation count read for each test case
LL tmp_sum1, tmp_sum2;   // scratch copies of sum1/sum2 taken while applying an add tag
// Recompute node rt's sums of a[i], a[i]^2 and a[i]^3 (mod MOD) from its two children.
void Push_Up(LL rt) {
    const LL lc = rt << 1;
    const LL rc = rt << 1 | 1;
    t[rt].sum1 = (t[lc].sum1 + t[rc].sum1) % MOD;
    t[rt].sum2 = (t[lc].sum2 + t[rc].sum2) % MOD;
    t[rt].sum3 = (t[lc].sum3 + t[rc].sum3) % MOD;
}
void Push_Down(LL rt) {
if(t[rt].set) {
LL u = t[rt].set
t[rt<<1].sum1 = (dist(rt<<1) * u) % MOD
t[rt<<1].sum2 = (dist(rt<<1) * u * u) % MOD
t[rt<<1].sum3 = (dist(rt<<1) * u * u * u) % MOD
t[rt<<1|1].sum1 = (dist(rt<<1|1) * u) % MOD
t[rt<<1|1].sum2 = (dist(rt<<1|1) * u * u) % MOD
t[rt<<1|1].sum3 = (dist(rt<<1|1) * u * u * u) % MOD
t[rt<<1].set = t[rt<<1|1].set = t[rt].set
t[rt<<1].add = t[rt<<1|1].add = 0
t[rt<<1].mul = t[rt<<1|1].mul = 1
t[rt].set = 0
}
if(t[rt].mul > 1) {
LL u = t[rt].mul
t[rt<<1].sum1 = (t[rt<<1].sum1 * u) % MOD
t[rt<<1].sum2 = (t[rt<<1].sum2 * u * u) % MOD
t[rt<<1].sum3 = (t[rt<<1].sum3 * u * u * u) % MOD
t[rt<<1|1].sum1 = (t[rt<<1|1].sum1 * u) % MOD
t[rt<<1|1].sum2 = (t[rt<<1|1].sum2 * u * u) % MOD
t[rt<<1|1].sum3 = (t[rt<<1|1].sum3 * u * u * u) % MOD
t[rt<<1].mul = (t[rt<<1].mul * u)%MOD
t[rt<<1|1].mul = (t[rt<<1|1].mul * u)%MOD
t[rt<<1].add = (t[rt<<1].add * u)%MOD
t[rt<<1|1].add = (t[rt<<1|1].add * u)%MOD
t[rt].mul = 1
}
if(t[rt].add) {
LL u = t[rt].add
tmp_sum1 = t[rt<<1].sum1
tmp_sum2 = t[rt<<1].sum2
t[rt<<1].sum1 = (t[rt<<1].sum1 + dist(rt<<1) * u)%MOD
t[rt<<1].sum2 = (t[rt<<1].sum2 + 2*u*tmp_sum1 + u*u*dist(rt<<1))%MOD
t[rt<<1].sum3 = (t[rt<<1].sum3 + 3*u*tmp_sum2 + 3*u*u*tmp_sum1 + u*u*u*dist(rt<<1)) % MOD
tmp_sum1 = t[rt<<1|1].sum1
tmp_sum2 = t[rt<<1|1].sum2
t[rt<<1|1].sum1 = (t[rt<<1|1].sum1 + dist(rt<<1|1) * u)%MOD
t[rt<<1|1].sum2 = (t[rt<<1|1].sum2 + 2*u*tmp_sum1 + u*u*dist(rt<<1|1))%MOD
t[rt<<1|1].sum3 = (t[rt<<1|1].sum3 + 3*u*tmp_sum2 + 3*u*u*tmp_sum1 + u*u*u*dist(rt<<1|1)) % MOD
t[rt<<1].add = (t[rt<<1].add + u)%MOD
t[rt<<1|1].add = (t[rt<<1|1].add + u)%MOD
t[rt].add = 0
}
}
// Build the segment tree for positions [x, y] rooted at node rt.
// Every node of the new tree gets its range set, all sums set to 0 and all
// lazy tags cleared, so state from a previous test case does not leak in.
void Build(LL x, LL y, LL rt) {
    t[rt].x = x;
    t[rt].y = y;   // was missing: dist(), Update() and Query() all read y
    t[rt].add = t[rt].set = 0;
    t[rt].mul = 1;
    if (x == y) {
        t[rt].sum1 = t[rt].sum2 = t[rt].sum3 = 0;
        return;
    }
    const LL mid = (x + y) >> 1;
    Build(x, mid, rt << 1);
    Build(mid + 1, y, rt << 1 | 1);
    Push_Up(rt);
}
// Apply one range operation to positions [left, right] inside the subtree of rt:
//   flag == 1 : a[i] += u
//   flag == 2 : a[i] *= u
//   otherwise : a[i]  = u   (input() passes 3 for this)
// Fully covered nodes are updated in place and record the operation as a lazy tag.
void Update(LL rt, LL left, LL right, LL u, LL flag) {
    if (t[rt].x >= left && t[rt].y <= right) {
        // Reduce the operand first: an unreduced u cubed times the node length can
        // overflow 64 bits for large operands. This also maps negative u into [0, MOD).
        const LL v = (u % MOD + MOD) % MOD;
        const LL v2 = v * v % MOD;
        const LL v3 = v2 * v % MOD;
        const LL len = dist(rt) % MOD;
        if (flag == 1) {
            const LL s1 = t[rt].sum1;   // sums before the shift
            const LL s2 = t[rt].sum2;
            t[rt].sum3 = (t[rt].sum3 + 3 * v % MOD * s2 % MOD + 3 * v2 % MOD * s1 % MOD + v3 * len % MOD) % MOD;
            t[rt].sum2 = (s2 + 2 * v % MOD * s1 % MOD + v2 * len % MOD) % MOD;
            t[rt].sum1 = (s1 + v * len) % MOD;
            t[rt].add = (t[rt].add + v) % MOD;
        } else if (flag == 2) {
            t[rt].sum1 = t[rt].sum1 * v % MOD;
            t[rt].sum2 = t[rt].sum2 * v2 % MOD;
            t[rt].sum3 = t[rt].sum3 * v3 % MOD;
            t[rt].mul = t[rt].mul * v % MOD;
            t[rt].add = t[rt].add * v % MOD;   // pending adds are scaled too
        } else {
            t[rt].sum1 = len * v % MOD;
            t[rt].sum2 = len * v2 % MOD;
            t[rt].sum3 = len * v3 % MOD;
            t[rt].mul = 1;   // an assignment cancels earlier multiply/add tags
            t[rt].add = 0;
            // set == 0 means "no pending assignment", so assigning 0 is stored as MOD
            // (congruent to 0) to keep the tag alive for the push-down.
            t[rt].set = v ? v : MOD;
        }
        return;
    }
    const LL mid = (t[rt].x + t[rt].y) >> 1;
    Push_Down(rt);
    if (mid >= left) Update(rt << 1, left, right, u, flag);
    if (mid < right) Update(rt << 1 | 1, left, right, u, flag);
    Push_Up(rt);
}
LL sum1,sum2,sum3
void Query(LL rt,LL left,LL right) {
if(t[rt].x >= (LL)left && t[rt].y <= (LL)right) {
sum1 = (sum1 + t[rt].sum1)%MOD
sum2 = (sum2 + t[rt].sum2)%MOD
sum3 = (sum3 + t[rt].sum3)%MOD
return
}
LL mid = (t[rt].x + t[rt].y) >> 1
Push_Down(rt)
if(mid >= left) Query(rt<<1,left,right)
if(mid < right) Query(rt<<1|1,left,right)
Push_Up(rt)
}
// Debug helper: print "x y sum2" for every node of rt's subtree in pre-order.
// Lazy tags are pushed down along the way, so this mutates the tree.
// Uses the standard %lld specifier; very old MSVCRT-based judges needed %I64d.
void Print_Tree(LL rt) {
    printf("%lld %lld %lld\n", t[rt].x, t[rt].y, t[rt].sum2);
    if (t[rt].x == t[rt].y) {
        return;
    }
    Push_Down(rt);
    Print_Tree(rt << 1);
    Print_Tree(rt << 1 | 1);
}
void input() {
Build(1,n,1)
LL x,y,u,ok
for(LL i = 1
// Print_Tree(1)
// puts("")
scanf("%I64d %I64d %I64d %I64d",&ok,&x,&y,&u)
if(ok <= 3) {
Update(1,x,y,u,ok)
}
else {
sum1 = sum2 = sum3 = 0
Query(1,x,y)
if(u == 1) printf("%I64d\n",sum1)
else if(u == 2) printf("%I64d\n",sum2)
else printf("%I64d\n",sum3)
}
}
}
// Placeholder kept for the usual input()/solve() driver shape.
// All of the work for a test case happens in input().
void solve() {
    return;
}
int main(void) {
//freopen("a.in","r",stdin)
while(scanf("%I64d %I64d",&n,&m),n+m) {
input()
solve()
}
}