给出一棵n个节点的树,每个节点都有一种颜色c[i]。
有q次操作:修改某个节点的颜色,或询问树上路径(l,r)上所有点的权值和。沿路径从l走到某个点i时,若颜色$c_i$到该点为止(含该点)已出现$k$次,则该点的权值为$v_{c_i}\cdot w_k$(题意较绕,细节请参考原题题面)。因此答案也可写成$\sum_{j} v_j \sum_{k=1}^{cnt_j} w_k$,其中$cnt_j$为路径上颜色$j$出现的次数。
这是带修改的树上莫队。分块按子树进行:DFS回溯时,若某个节点下方已累积的未分块节点数达到块大小,就把它们划为一块;最后剩余的节点(含根)归为最后一块。
// Records parent, depth, DFS entry order, subtree size and heavy child for every
// node, and partitions the tree into blocks for Mo's ordering. Nodes stay on `st`
// until an ancestor seals them into a block of at least block_size nodes; the
// root (node 1) flushes everything still on the stack into one final block.
void dfs1(int t, int f){
    fa[t] = f;
    dp[t] = dp[f] + 1;
    dfn[t] = ++clk;
    sz[t] = 1;
    st.push(t);                     // t sits beneath its descendants' leftovers
    const int base = st.size();     // stack height just above t
    for (int child : g[t]) {
        if (child == f) continue;
        dfs1(child, t);
        sz[t] += sz[child];
        if (sz[child] > sz[hs[t]]) hs[t] = child;   // heaviest child so far
        // Enough unassigned descendants accumulated above t: make them one block.
        if (st.size() - base >= block_size) {
            for (; st.size() > base; st.pop()) block_num[st.top()] = cur_block;
            ++cur_block;
        }
    }
    if (t != 1) return;
    // Root: every node still on the stack (the root included) forms the last block.
    for (; !st.empty(); st.pop()) block_num[st.top()] = cur_block;
    ++cur_block;
}
最后就是正常的莫队操作了。
树上莫队的难点在于[l,r]的边界移动。
假设从(l,r)移动到(l',r')。记S(l,r)为路径(l,r)上除lca(l,r)以外的点的集合,则有S(l',r) = S(l,r) xor S(l,l')。所以把l移动到l'时,只需对S(l,l')上的点的标记取反(异或),再由标记决定该点是加入还是删除;r移动到r'同理。
如果把lca也一起取反,就会出现重复操作,见https://www.luogu.com.cn/blog/dedicatus545/solution-p4074
因此l走到l'时,lca(l,l')本身不操作;r走到r'同理。
void move(int old_x, int new_x){
if(dp[old_x] < dp[new_x]) swap(old_x,new_x);
while(dp[old_x]!=dp[new_x]){
change(old_x);
old_x = fa[old_x];
}
while(old_x != new_x){
change(old_x);
change(new_x);
old_x = fa[old_x];
new_x = fa[new_x];
}
}
每次询问开始前,先对旧的lca(l,r)取反,此时维护的集合恰为S(l,r)。移动端点之后,再对新的lca(l',r')取反,把它加回来。这样两个lca都只被操作一次,记录答案时维护的正好是整条路径。
// Main Mo loop (excerpt of main()). Between queries, vis marks the whole
// path(l, r) including lca(l, r), and sum is the answer for that set.
for(int i = 0; i < q.size();++i){
// NOTE(review): swapping here, after sort(), means the block/r_block sort keys
// (computed from the unswapped endpoints when the query was read) may not match
// the endpoints the l/r pointers actually follow; canonicalising when the query
// is read keeps the sort keys and the pointer movement consistent.
if(dfn[q[i].l]>dfn[q[i].r]) swap(q[i].l,q[i].r);
// Remove the old LCA: vis now holds path(l, r) minus lca(l, r).
int f = lca(l,r);
change(f);
// Time travel forward: apply modifications until exactly q[i].t are in effect.
while(cur<q[i].t){
cur++;
int pos = modi[cur].pos;
int old_c = modi[cur].old_c;
int new_c = modi[cur].c;
change(pos, old_c, new_c);
}
// Time travel backward: undo modifications newer than q[i].t.
while(cur>q[i].t){
int pos = modi[cur].pos;
int old_c = modi[cur].old_c;
int new_c = modi[cur].c;
change(pos, new_c,old_c);
cur--;
}
// Move each endpoint by xor-ing the path between its old and new position.
if(l!=q[i].l) {
move(l,q[i].l);
l = q[i].l;
}
if(r!=q[i].r){
move(r,q[i].r);
r = q[i].r;
}
// Put the new LCA back so the whole path is counted, then record the answer.
f = lca(l,r);
change(f);
ans[q[i].id] = sum;
}
另外,每个询问的(l,r)应按dfs序规范化(保证dfn[l] ≤ dfn[r])。这一步最好在读入询问时、计算分块编号和排序之前完成;否则排序依据的块可能与指针实际移动的端点不一致,从而破坏莫队的复杂度保证。
完整代码
int block_size;   // target block size for the tree partition built by dfs1

// One path query (l, r), answered after modifications 0..t have been applied.
struct query{
    int l,r,id;   // path endpoints and position of the answer in the output
    int block;    // block_num of l (primary sort key)
    int r_block;  // block_num of r (secondary sort key)
    int t;        // index of the last modification visible to this query (-1 if none)
    // Mo's-with-modifications order: block of l, then block of r, then query
    // index. main() assigns id and t in input order, so the last key is also
    // time order.
    bool operator<(const query &o) const{
        if(block != o.block) return block < o.block;
        if(r_block != o.r_block) return r_block < o.r_block;
        return id < o.id;
    }
};
// NOTE: the former global array `q[N]` was never used (main() declares its own
// local vector `q` that shadows it) and has been removed.

// One colour modification: node pos changes from colour old_c to colour c.
struct Revise{
    int pos, c, old_c;
};
int v[N],w[N];          // v: value per colour; w: multiplier for the k-th occurrence
int c[N], num[N],old[N];// c: current colour per node; num: occurrences per colour in the
                        // maintained set; old: colour per node after the modifications read so far
ll ans[N];              // answers indexed by query id
ll sum;                 // answer for the currently maintained node set
stack<int> st;          // pending (not yet blocked) nodes during dfs1
int cur_block,clk;      // next block index; DFS timestamp counter
int hs[N],sz[N], block_num[N];        // heavy child, subtree size, block of each node
int fa[N],top[N],dp[N],vis[N],dfn[N]; // parent, chain head, depth, in-set flag, DFS order
vector<int> g[N];       // adjacency lists
// Records parent, depth, DFS entry order, subtree size and heavy child for every
// node, and partitions the tree into blocks for Mo's ordering. Nodes stay on `st`
// until an ancestor seals them into a block of at least block_size nodes; the
// root (node 1) flushes everything still on the stack into one final block.
void dfs1(int t, int f){
    fa[t] = f;
    dp[t] = dp[f] + 1;
    dfn[t] = ++clk;
    sz[t] = 1;
    st.push(t);                     // t sits beneath its descendants' leftovers
    const int base = st.size();     // stack height just above t
    for (int child : g[t]) {
        if (child == f) continue;
        dfs1(child, t);
        sz[t] += sz[child];
        if (sz[child] > sz[hs[t]]) hs[t] = child;   // heaviest child so far
        // Enough unassigned descendants accumulated above t: make them one block.
        if (st.size() - base >= block_size) {
            for (; st.size() > base; st.pop()) block_num[st.top()] = cur_block;
            ++cur_block;
        }
    }
    if (t != 1) return;
    // Root: every node still on the stack (the root included) forms the last block.
    for (; !st.empty(); st.pop()) block_num[st.top()] = cur_block;
    ++cur_block;
}
// Assigns top[t] = head of the heavy chain containing t. f is the head of the
// chain t belongs to; the heavy child inherits it, each light child starts a new one.
void dfs2(int t, int f){
    top[t] = f;
    if (hs[t] == 0) return;          // no heavy child: t is a leaf
    dfs2(hs[t], f);
    for (int child : g[t])
        if (child != fa[t] && child != hs[t])
            dfs2(child, child);
}
// Lowest common ancestor via heavy-path chains: repeatedly jump the endpoint
// whose chain head is deeper to just above that head, then return the
// shallower of the two once both lie on the same chain.
int lca(int x, int y){
    if (dp[y] > dp[x]) swap(x, y);
    for (; top[x] != top[y]; ) {
        int &climber = dp[top[x]] > dp[top[y]] ? x : y;
        climber = fa[top[climber]];
    }
    return dp[y] <= dp[x] ? y : x;
}
// Removes one occurrence of colour idx from the maintained set; that
// occurrence (the num[idx]-th) contributed v[idx] * w[num[idx]] to sum.
void del(int idx){
    const int k = num[idx]--;
    sum -= 1ll * v[idx] * w[k];
}
// Adds one occurrence of colour idx to the maintained set; the new
// (num[idx]-th) occurrence contributes v[idx] * w[num[idx]] to sum.
void add(int idx){
    const int k = ++num[idx];
    sum += 1ll * v[idx] * w[k];
}
// Applies a colour modification to node pos. If pos is currently counted
// (vis[pos] set), its contribution moves from colour old_c to colour new_c.
void change(int pos, int old_c, int new_c){
    c[pos] = new_c;
    if (!vis[pos]) return;
    del(old_c);
    add(new_c);
}
// Toggles node idx in or out of the maintained set, updating sum using the
// node's current colour c[idx].
void change(int idx){
    vis[idx] ^= 1;
    if (vis[idx]) add(c[idx]);
    else del(c[idx]);
}
void move(int old_x, int new_x){
if(dp[old_x] < dp[new_x]) swap(old_x,new_x);
while(dp[old_x]!=dp[new_x]){
change(old_x);
old_x = fa[old_x];
}
while(old_x != new_x){
change(old_x);
change(new_x);
old_x = fa[old_x];
new_x = fa[new_x];
}
}
int main(){
int n,m,Q;
scanf("%d%d%d",&n,&m,&Q);
block_size = pow(n,2.0/3);
clk = 0;
fr(i,1,m+1) {
sf("%d",&v[i]);
}
fr(i,1,n+1) {
sf("%d",&w[i]);
}
fr(i,0,n-1){
int u,v;
sf("%d%d",&u,&v);
g[u].pb(v);
g[v].pb(u);
}
fr(i,1,n+1){
sf("%d",&c[i]);
old[i] = c[i];
}
dfs1(1,1);
dfs2(1,1);
vector<query> q;
vector<Revise> modi;
fr(i,0,Q){
int t,x,y;
sf("%d%d%d",&t,&x,&y);
if(t==1){
query qy;
qy.l = x;qy.r = y;
qy.id = q.size();
qy.block = block_num[qy.l];
qy.r_block = block_num[qy.r];
qy.t = modi.size()-1;
q.pb(qy);
}
else {
Revise r;
r.pos = x;r.c = y;
r.old_c = old[r.pos];
modi.pb(r);
old[r.pos] = r.c;
}
}
sort(q.begin(),q.end());
sum = 0;
int l = 1, r = 1;
change(l);
int cur = -1;
for(int i = 0; i < q.size();++i){
if(dfn[q[i].l]>dfn[q[i].r]) swap(q[i].l,q[i].r);
int f = lca(l,r);
change(f);
while(cur<q[i].t){
cur++;
int pos = modi[cur].pos;
int old_c = modi[cur].old_c;
int new_c = modi[cur].c;
change(pos, old_c, new_c);
}
while(cur>q[i].t){
int pos = modi[cur].pos;
int old_c = modi[cur].old_c;
int new_c = modi[cur].c;
change(pos, new_c,old_c);
cur--;
}
if(l!=q[i].l) {
move(l,q[i].l);
l = q[i].l;
}
if(r!=q[i].r){
move(r,q[i].r);
r = q[i].r;
}
f = lca(l,r);
change(f);
ans[q[i].id] = sum;
}
fr(i,0,q.size()){
printf("%lld\n",ans[i]);
}
}