E. Tree
time limit per test1.5 seconds
memory limit per test256 megabytes
inputstandard input
outputstandard output
You are given a tree with n nodes and q queries.
Every query starts with three integers k, m and r, followed by k nodes of the tree a1,a2,…,ak. To answer a query, assume that the tree is rooted at r. We want to divide the k given nodes into at most m groups such that the following conditions are met:
Each node should be in exactly one group and each group should have at least one node.
In any group, there should be no two distinct nodes such that one node is an ancestor (direct or indirect) of the other.
You need to output the number of ways modulo 109+7 for every query.
Input
The first line contains two integers n and q (1≤n,q≤105) — the number of vertices in the tree and the number of queries, respectively.
Each of the next n−1 lines contains two integers u and v (1≤u,v≤n,u≠v), denoting an edge connecting vertex u and vertex v. It is guaranteed that the given graph is a tree.
Each of the next q lines starts with three integers k, m and r (1≤k,r≤n, 1≤m≤min(300,k)) — the number of nodes, the maximum number of groups and the root of the tree for the current query, respectively. They are followed by k distinct integers a1,a2,…,ak (1≤ai≤n), denoting the nodes of the current query.
It is guaranteed that the sum of k over all queries does not exceed 105.
Output
Print q lines, where the i-th line contains the answer to the i-th query.
很棒的一道题!而且完全触及了我的盲区!
学长用了不到一个小时就切了这题,真的佩服得五体投地!
首先需要想到dp方程:
d
p
[
i
]
[
j
]
=
d
p
[
i
−
1
]
[
j
]
∗
(
j
−
h
[
i
]
)
+
d
p
[
i
−
1
]
[
j
−
1
]
dp[i][j]=dp[i-1][j]*(j-h[i])+dp[i-1][j-1]
dp[i][j]=dp[i−1][j]∗(j−h[i])+dp[i−1][j−1]
其中
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j]表示前i个节点中分成j个集合的方案总数,分两种情况,一种是第i个节点在之前某个集合中,一种是自成一个新集合。
h
[
i
]
h[i]
h[i]表示的是i的祖先中属于给出节点的个数。
方程比较简单,但是需要遵循一个更新顺序。由于方程成立的条件是
h
[
i
]
h[i]
h[i]代表的节点必须已经考虑分配过,所以我们自然会想到节点按照dfs序排序,再进行dp。但是这题又一个难点“换根”就难以解决,换根之后的dfs序不太可能重新来求。
仔细观察,我们发现
h
[
i
]
h[i]
h[i]越大,深度必然越深,按照
h
[
i
]
h[i]
h[i]从小到大的顺序dp是更方便的选择,而且这样的话,“换根”问题可以得到解决,具体操作如下。
先以1为根,求出dfs序和子树大小,当询问到来时,先对每个标记点的子树区间进行区间加1(实际上可以是一个树状数组,起点+1,终点之后-1,然后求前缀和就得到了区间覆盖数,当然非要写区间线段树的点查询也OK),这样的话我们求出了
T
M
P
[
i
]
TMP[i]
TMP[i]代表以1为根时,i以及i的祖先中是标记点的个数。然后我们用LCA的方式推出以新根为根时,i的祖先中是标记点的个数:
h
[
i
]
=
T
M
P
[
a
[
i
]
]
+
T
M
P
[
r
o
o
t
]
−
2
∗
T
M
P
[
l
c
a
]
+
m
a
r
k
[
l
c
a
]
−
1
h[i]=TMP[a[i]]+TMP[root]-2*TMP[lca]+mark[lca]-1
h[i]=TMP[a[i]]+TMP[root]−2∗TMP[lca]+mark[lca]−1
这样的话求出h[i]之后,我们把h数组排序,就可以放心dp了。
题目中说明了m不会超过300,k的总和不会超过1E5,这也保证了我们dp的时间不会太久。而其余的效率都是带log的,有保障。
#include<cstdio>
#include<algorithm>
#define mo 1000000007
using namespace std;
using LL=long long;
struct Finwick
{
int n,C[100005];
inline void init(int x)
{
n=x;
}
int sum(int x)
{
int res=0;
while(x)
res+=C[x],x-=x&-x;
return res;
}
int add(int x, int d)
{
while(x<=n)
C[x]+=d,x+=x&-x;
}
}F;
int n,q,k,m,ans,rt,a[100005];
int rt_ans,h[100005],dp[305];
int dfs_clock,dfn[100005],son[100005];
int dep[100005],p[100005][18];
vector<int> E[100005];
bool mrk[100005];
void dfs(int i, int fa)
{
dep[i]=dep[fa]+1;
son[i]=1;
p[i][0]=fa;
for(int j=1;(1<<j)<dep[i];j++)
p[i][j]=p[p[i][j-1]][j-1];
dfn[i]=++dfs_clock;
for(int v:E[i])
if(v!=fa)
dfs(v,i),son[i]+=son[v];
}
int LCA(int a, int b)
{
if(dep[a]<dep[b])
swap(a,b);
int i,j,res=0;
for(i=0;(1<<i)<dep[a];i++);
i--;
for(j=i;j>=0;j--)
if(dep[a]-(1<<j)>=dep[b])
a=p[a][j];
if(a==b)
return a;
for(j=i;j>=0;j--)
if(p[a][j]!=p[b][j])
a=p[a][j],b=p[b][j];
return p[a][0];
}
int main()
{
scanf("%d%d",&n,&q);
for(int i=1,u,v;i<n;i++)
scanf("%d%d",&u,&v),E[u].push_back(v),E[v].push_back(u);
dfs(1,0);
while(q--)
{
scanf("%d%d%d",&k,&m,&rt);
F.init(n);
for(int i=1;i<=k;i++) //添加区间和标记
{
scanf("%d",&a[i]);
mrk[a[i]]=true;
F.add(dfn[a[i]],1);
F.add(dfn[a[i]]+son[a[i]],-1);
}
rt_ans=F.sum(dfn[rt]);
for(int i=1,lca;i<=k;i++)
{
lca=LCA(a[i],rt);
h[i]=F.sum(dfn[a[i]])+rt_ans-2*F.sum(dfn[lca])+mrk[lca]-1;
}
sort(h+1,h+k+1);
for(int i=1;i<=k;i++) //擦除区间和标记
{
F.add(dfn[a[i]],-1);
F.add(dfn[a[i]]+son[a[i]],1);
mrk[a[i]]=false;
}
for(int i=0;i<=m;i++)
dp[i]=0;
dp[0]=1;
for(int i=1;i<=k;i++)
for(int j=min(i,m);j>=0;j--)
if(j>h[i])
dp[j]=((LL)dp[j]*(j-h[i])%mo+dp[j-1])%mo;
else
dp[j]=0;
ans=0;
for(int i=1;i<=m;i++)
ans=(ans+dp[i])%mo;
printf("%d\n",ans);
}
return 0;
}