一:……noip2016,感慨万千,勉强一等。
二:来看看running这题,题面略。其实除了树上差分以外,还有没有方法可以解决这题呢?其实是有的,比如树链剖分+离线标记。之前我们就分析过,我们要对每一个i求i的子树中有多少个值为A[i]的dep[u](这里不考虑B[i],因为做法一样)。我们可以利用树链剖分,把u->Lca(u,v)这一段分成不超过log(n)段,而每一段在Dfs序中都是连续的。也就是说,我们要对这log(n)段都打上值为dep[u]的标记。
考虑到本题是先打很多个标记,在最后查询一遍求ans[i],故可以用离线算法,在每一个连续的Dfs序(L,R)的L打一个值为dep[u]的标记,在R+1的地方删除它。最后再一遍按Dfs序从左往右扫,加减标记就行了。
由于一条链需要打4*log(n)条标记,最后扫一遍需要将所有的标记加进来,故时间复杂度为O(m*log(n)),如果用邻接链表储存所有标记的话,这部分的空间是4*m*log(n),不会超时,炸空间。
其实本人在考场上想过用树剖,但以为树剖一定是n*log^2(n),而数据是3*10^5,故没往那一块儿想。本题的树剖因为没有用线段树操作,所以时间少乘了一个log(n)。
现在总结一下,目测所有树上差分的题都可以用树剖做,因为对一个点打标记,影响的将会是它到root的这一段,就是一条链。但单纯的树上差分,且只查询一个值,时间是o(n)的,一遍Dfs就行,而树剖难写,又要多一个log(n)……
CODE(然而很遗憾,洛谷上测只有95,会超一个点,常数太大……):
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
#include<set>
using namespace std;
const int maxn=300100;
const int maxl=20;
struct data
{
int obj,_Next;
} e[maxn<<1];
int head[maxn];
int cur=-1;
struct data1
{
int id,val,num,_Next1;
} e1[maxn*maxl<<2];
int head1[maxn];
int cur1=-1;
int fa[maxn][maxl];
int dep[maxn];
int _Size[maxn];
int _Son[maxn];
int w[maxn];
int _Time;
int dfsx[maxn];
int top[maxn];
int que[maxn];
int he=0,ta=1;
int A[maxn];
int B[maxn];
int cntA[maxn];
int cntB[maxn<<1];
int ans[maxn];
int n,m;
void Add(int x,int y)
{
cur++;
e[cur].obj=y;
e[cur]._Next=head[x];
head[x]=cur;
}
void Bfs1()
{
que[1]=1;
fa[1][0]=1;
dep[1]=1;
while (he<ta)
{
he++;
int node=que[he];
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if (son!=fa[node][0])
{
fa[son][0]=node;
dep[son]=dep[node]+1;
ta++;
que[ta]=son;
}
p=e[p]._Next;
}
}
}
void Bfs2()
{
for (int i=1; i<=n; i++) _Size[i]=1;
for (int i=n; i>=2; i--)
{
int son=que[i];
int node=fa[son][0];
_Size[node]+=_Size[son];
if (_Size[son]>_Size[ _Son[node] ])
_Son[node]=son;
}
}
void Bfs3()
{
top[1]=1;
w[1]=1;
dfsx[1]=1;
for (int i=1; i<n; i++)
{
int node=que[i];
int heavy_son=_Son[node];
_Time=w[node]+1;
if (heavy_son!=0)
{
top[heavy_son]=top[node];
w[heavy_son]=w[node]+1;
dfsx[ w[heavy_son] ]=heavy_son;
_Time+=_Size[heavy_son];
}
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if ( son!=heavy_son && son!=fa[node][0] )
{
top[son]=son;
w[son]=_Time;
dfsx[ w[son] ]=son;
_Time+=_Size[son];
}
p=e[p]._Next;
}
}
}
void Make_fa()
{
for (int j=1; j<maxl; j++)
for (int i=1; i<=n; i++)
fa[i][j]=fa[ fa[i][j-1] ][j-1];
}
int Lca(int u,int v)
{
if (dep[u]<dep[v]) swap(u,v);
for (int j=maxl-1; j>=0; j--)
if (dep[ fa[u][j] ]>=dep[v])
u=fa[u][j];
if (u==v) return u;
for (int j=maxl-1; j>=0; j--)
if (fa[u][j]!=fa[v][j])
{
u=fa[u][j];
v=fa[v][j];
}
return fa[u][0];
}
void Add1(int x,int nid,int nval,int nnum)
{
cur1++;
e1[cur1].id=nid;
e1[cur1].val=nval;
e1[cur1].num=nnum;
e1[cur1]._Next1=head1[x];
head1[x]=cur1;
}
void Plus(int u,int v,int nid,int nval)
{
if (top[u]==top[v])
{
if (w[u]>w[v]) swap(u,v);
Add1(w[u],nid,nval,1);
Add1(w[v]+1,nid,nval,-1);
return;
}
if (dep[ top[u] ]<dep[ top[v] ]) swap(u,v);
int tu=top[u];
Add1(w[tu],nid,nval,1);
Add1(w[u]+1,nid,nval,-1);
Plus(fa[tu][0],v,nid,nval);
}
int main()
{
freopen("c.in","r",stdin);
freopen("c.out","w",stdout);
scanf("%d%d",&n,&m);
for (int i=1; i<=n; i++) head[i]=-1;
for (int i=1; i<n; i++)
{
int a,b;
scanf("%d%d",&a,&b);
Add(a,b);
Add(b,a);
}
Bfs1();
Bfs2();
Bfs3();
Make_fa();
for (int i=1; i<=n; i++)
{
int W;
scanf("%d",&W);
A[i]=dep[i]+W;
B[i]=dep[i]-W;
}
for (int i=1; i<=n+1; i++) head1[i]=-1;
for (int i=1; i<=m; i++)
{
int u,v;
scanf("%d%d",&u,&v);
int lca=Lca(u,v);
int len=dep[u]+dep[v]-2*dep[lca];
Plus(u,lca,0,dep[u]);
Plus(v,lca,1,dep[v]-len);
if (A[lca]==dep[u]) ans[lca]--;
}
for (int i=1; i<=n; i++)
{
int p=head1[i];
while (p!=-1)
{
int nid=e1[p].id;
int nval=e1[p].val;
int nnum=e1[p].num;
if (nid==0) cntA[nval]+=nnum;
else cntB[nval+maxn]+=nnum;
p=e1[p]._Next1;
}
int node=dfsx[i];
ans[node]+=cntA[ A[node] ];
ans[node]+=cntB[ B[node]+maxn ];
}
for (int i=1; i<=n; i++) printf("%d ",ans[i]);
printf("\n");
return 0;
}
三:接下来想讲一下关于水分的事情……
主要讲两个部分:一条链的和T=1的。一条链的该怎么做呢?我们可以把路径分为两种情况:S<=T和S>T。我们只讨论前者,因为后者是一样的。
如果s能够更新u,那么必定满足s+w[u]=u,且t>=u。那么,对于每一个节点i,我们只需要查看以(i-w[i])为S的节点有多少个T>=i就可以了。对于每一个i,我们都开一个数组记录一下它的T,排个序,然后每一次找的时候二分就好了。
现在我们来证明一下时间复杂度。预处理的时间为a1*log(a1)+a2*log(a2)……an*log(an)(ai是s=i的路径的个数,a1+a2+……an=m)<=a1*log(m)+a2*log(m)+……an*log(m)<=m*log(m)。
查询的时间也可类似地证明不会超过n*log(m),故时间复杂度为O(n*log(m)+m*log(m))。
那么空间呢?每一个点i都开一个数组,空间不是O(n^2)的吗?我们可以把所有的小数组合并成一个大数组,然后记录一下每一个小数组的始末位置,这样数组的空间就是O(m)的了。如图:
好了,接下来讲一讲T=1的时候……关于这个,我们需要恶补一下multiset的用法。multiset是一个stl库中存储可重复元素的集合,内部的存储形式是一棵平衡树,故它的大部分函数都是log(n)的,其中n是集合中元素的个数:
multiset<int,Comp> s;//开一个元素为int,内部排序法则为Comp(没有则默认小的在前)的multiset。
s.insert(x);//在s中插入一个元素x。
s.count(x)//返回s中元素x的个数。
s.empty();//s为空时返回真。
s.erase(x);//删除s中所有值为x的元素。
s.clear();//清空所有元素。
multiset<int> :: iterator i;//定义一个迭代器,你可以理解为是存储int类型的multiset专用的指针。
s.begin();//返回s中第一个元素的迭代器(地址)。
s.end();//返回s中最后一个元素的迭代器(地址)。
例如,令i= s.begin();则*i就是s中的第一个元素。
……
有了这些我们就可以做最简单的multiset应用了。
对于T=1,我们需要快速查询对于每一个i,它的子树里有多少个值为dep[i]+w[i]的标记。我们可以先Dfs点i的子树,然后把它子树的multiset合并起来,变成i的multiset,查询答案即可。我们用启发式合并(即每一次把小的合并到大的)。这样,一个元素只会被转移log(n)次,一次转移需要log(n)的时间,故时间复杂度为O(n*log^2(n))。由于一个元素只会在log(n)个multiset中出现,故空间复杂度为n*log(n)。
看到n=10^5,我一开始还是蛮有信心将常数些小一点过那四个点的,结果因为调用stl库,超级慢,一个点要10s左右QAQ……
附CODE(用10s骗那20分……):
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
#include<set>
//#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
const int maxn=100000;
struct data
{
int obj,_Next;
} e[maxn<<1];
int head[maxn];
int cur=-1;
multiset <int> s[maxn];
int dep[maxn];
int w[maxn];
int fa[maxn];
int cnt[maxn];
int id[maxn];
int que[maxn];
int he=0,ta=1;
int ans[maxn];
int n,m;
void Add(int x,int y)
{
cur++;
e[cur].obj=y;
e[cur]._Next=head[x];
head[x]=cur;
}
void Bfs1()
{
que[1]=1;
while (he<ta)
{
he++;
int node=que[he];
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if (son!=fa[node])
{
fa[son]=node;
dep[son]=dep[node]+1;
ta++;
que[ta]=son;
}
p=e[p]._Next;
}
}
}
void Up(int node,int son)
{
int x=id[node];
int y=id[son];
if ( s[x].size()<s[y].size() )
{
swap(id[node],id[son]);
swap(x,y);
}
multiset <int>::iterator S=s[y].begin();
multiset <int>::iterator T=s[y].end();
while (S!=T)
{
int z=*S;
int num=s[y].count(z);
for (int i=1; i<=num; i++)
{
s[x].insert(z);
S++;
}
}
}
bool Comp(int x,int y)
{
return dep[x]>dep[y];
}
void Bfs2()
{
for (int i=1; i<=n; i++) que[i]=i;
sort(que+1,que+n+1,Comp);
for (int i=1; i<=n; i++)
for (int j=1; j<=cnt[i]; j++) s[i].insert(dep[i]);
for (int i=1; i<=n; i++) id[i]=i;
for (int i=1; i<n; i++)
{
int son=que[i];
int node=fa[son];
ans[son]=s[ id[son] ].count(dep[son]+w[son]);
Up(node,son);
}
ans[1]=s[ id[1] ].count(dep[1]+w[1]);
}
int main()
{
freopen("running.in","r",stdin);
freopen("running.out","w",stdout);
scanf("%d%d",&n,&m);
for (int i=1; i<=n; i++) head[i]=-1;
for (int i=1; i<n; i++)
{
int a,b;
scanf("%d%d",&a,&b);
Add(a,b);
Add(b,a);
}
for (int i=1; i<=n; i++) scanf("%d",&w[i]);
fa[1]=1;
dep[1]=0;
Bfs1();
bool fT=true;
for (int i=1; i<=m; i++)
{
int S,T;
scanf("%d%d",&S,&T);
if (T!=1)
{
fT=false;
break;
}
cnt[S]++;
}
if (!fT) return 0;
Bfs2();
for (int i=1; i<=n; i++) printf("%d ",ans[i]);
printf("\n");
return 0;
}