Description
给出一棵nn个节点的树,节点点权为cici,mm次查询,每次查询从节点到tt节点的简单路径上,权值在区间之间的点的权值和
Input
多组用例,每组用例首先输入两个整数n,mn,m表示点数和查询数,之后输入nn个整数表示ii节点的点权,之后行每行输入两个整数u,vu,v表示一条树边,最后mm行每行输入四个整数表示一次查询(1≤n,m≤105,1≤ci≤109,1≤a≤b≤109)(1≤n,m≤105,1≤ci≤109,1≤a≤b≤109)
Output
对于每次查询,输出结果
Sample Input
5 3
1 2 1 3 2
1 2
2 4
3 1
2 5
4 5 1 3
1 1 1 1
3 5 2 3
Sample Output
7 1 4
Solution
首先树链剖分把一条树上路径转化为若干段连续的区间,问题转化为区间中权值介于[a,b][a,b]的数字之和,用线段树维护区间最值和区间和,如果区间最小值不小于aa且区间最大值不超过,说明区间所有数字都介于[a,b][a,b],对答案的贡献就是区间和,如果区间最大值小于aa或者区间最小值大于则该区间对答案没有贡献,否则把区间分成两部分再分别求解,时间复杂度O(mlog22n)O(mlog22n)
Code
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
typedef long long ll;
#define maxn 100005
struct Edge
{
int to,next;
}edge[2*maxn];
int head[maxn],tot;
int idx;//时间戳,用于得到dfs序
int size[maxn];//size[i]表示以i节点为根节点的子树中节点个数
int fa[maxn];// fa[i]表示i节点的父亲节点
int son[maxn];//son[i]表示i节点的重儿子节点(如果没有重儿子则son[i]=i)
int dep[maxn];//dep[i]表示i节点在树中的深度
int top[maxn];//top[i]表示i节点所处重链中深度最低的节点
int id[maxn];//id[i]表示i节点的dfs序
int pos[maxn];//pos[i]表示dfs序为i的节点
void init()
{
idx=tot=0;
dep[1]=fa[1]=0;//默认1节点为根节点,初始化其父亲节点为0,深度为0
memset(son,0,sizeof(son));
memset(head,-1,sizeof(head));
}
void add(int u,int v)
{
edge[tot].to=v;
edge[tot].next=head[u];
head[u]=tot++;
}
void dfs1(int u)//得到size,fa,dep,son
{
size[u]=1;
for(int i=head[u];~i;i=edge[i].next)
{
int v=edge[i].to;
if(v!=fa[u])
{
fa[v]=u;//v的父亲节点是u
dep[v]=dep[u]+1;//儿子的深度等于父亲的深度加一
dfs1(v);//深搜
size[u]+=size[v];//父亲节点为根节点的子树节点个数等于各儿子节点为根节点的子树节点个数之和加一(父亲节点本身)
if(size[son[u]]<size[v]) son[u]=v;//更新重儿子
}
}
}
void dfs2(int u,int topu)//得到top,id,pos,l,r
{
top[u]=topu;
id[u]=++idx;//得到u节点的dfs序
pos[idx]=u;//记录这个dfs序对应的节点
if(son[u]) dfs2(son[u],top[u]);//有重儿子首先深搜重儿子
for(int i=head[u];~i;i=edge[i].next)
{
int v=edge[i].to;
if(v!=fa[u]&&v!=son[u]) dfs2(v,v);//深搜所有儿子节点
}
}
int n,m,c[maxn];
#define ls (t<<1)
#define rs ((t<<1)|1)
int Max[maxn<<2],Min[maxn<<2];
ll Sum[maxn<<2];
void push_up(int t)
{
Max[t]=max(Max[ls],Max[rs]);
Min[t]=min(Min[ls],Min[rs]);
Sum[t]=Sum[ls]+Sum[rs];
}
void build(int l,int r,int t)
{
if(l==r)
{
Max[t]=Min[t]=Sum[t]=c[pos[l]];
return ;
}
int mid=(l+r)/2;
build(l,mid,ls);build(mid+1,r,rs);
push_up(t);
}
ll query(int L,int R,int l,int r,int t,int a,int b)
{
if(L<=l&&r<=R)
{
if(Min[t]>=a&&Max[t]<=b)return Sum[t];
if(Min[t]>b||Max[t]<a||l==r)return 0;
}
int mid=(l+r)/2;
ll ans=0;
if(L<=mid)ans+=query(L,R,l,mid,ls,a,b);
if(R>mid)ans+=query(L,R,mid+1,r,rs,a,b);
return ans;
}
ll Solve(int u,int v,int a,int b)
{
ll ans=0;
int top1=top[u],top2=top[v];
while(top1!=top2)
{
if(dep[top1]<dep[top2])
{
swap(top1,top2);
swap(u,v);
}
ans+=query(id[top1],id[u],1,n,1,a,b);
u=fa[top1];
top1=top[u];
}
if(dep[u]>dep[v]) swap(u,v);
ans+=query(id[u],id[v],1,n,1,a,b);
return ans;
}
int main()
{
while(~scanf("%d%d",&n,&m))
{
init();
for(int i=1;i<=n;i++)scanf("%d",&c[i]);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs1(1);
dfs2(1,1);
build(1,n,1);
for(int i=1;i<=m;i++)
{
int l,r,a,b;
scanf("%d%d%d%d",&l,&r,&a,&b);
printf("%I64d%c",Solve(l,r,a,b),i==m?'\n':' ');
}
}
return 0;
}