1.题目描述:点击打开链接
2.解题思路:本题让我长见识了。也学到了很多新的知识:LCA,多级祖先算法。如果只是单纯地将无根树转化为有根树,找到u,v的中点,再用BFS计算中线上结点的个数,那么最终会导致TLE。本题的高效算法如下:
首先求出以1为根的树的所有结点的总个数,保存在num数组中,再利用LCA算法求出u,v的公共祖先,设为LCA。找到u,v结点的中点mid.此时规定deep[u]>deep[v],不满足就交换两数。那么分两种情况讨论:(i)如果d(u,LCA)==d(v,LCA)(d表示距离),那么答案是n-num[u]-num[v];(ii)如果d(u,LCA)!=d(v,LCA),设偏向u的中点是midU(mid--即可),那么答案是num[mid]-num[midU]。
3.代码:
#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
#include<algorithm>
#include<string>
#include<sstream>
#include<set>
#include<vector>
#include<stack>
#include<map>
#include<queue>
#include<deque>
#include<cstdlib>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<ctime>
#include<functional>
using namespace std;
#define N 101010+10
int num[N];
int head[N];//链表表头,head[u]表示与u相连的边序号
int tot;//边序号
int deep[N];//结点深度
int p[N][30];
struct Node
{
int to, next;
}edge[N << 1];//存放所有的边
void addedge(int from, int to)
{
edge[tot].to = to;
edge[tot].next = head[from];
head[from] = tot++;
}
void dfs(int u, int fa)//无根树转化为1为根的有根树
{
for (int i = head[u]; ~i; i = edge[i].next)
{
int v = edge[i].to;
if (v == fa)continue;
deep[v] = deep[u] + 1;
p[v][0] = u;//直接祖先
dfs(v, u);
}
}
int lca(int a, int b)//求最近公共祖先
{
if (deep[a] < deep[b])
swap(a, b);
int d = deep[a] - deep[b];
for (int i = 0; i < 30;i++)
if (d&(1 << i))
a = p[a][i];
if (a == b)
return a;
for (int i = 29; i >= 0;i--)
if (p[a][i] != p[b][i])
{
a = p[a][i];
b = p[b][i];
}
return p[a][0];
}
int dfs2(int u, int fa)//计算u为根的子树的结点数
{
num[u] = 1;
for (int i = head[u]; ~i; i = edge[i].next)
{
int v = edge[i].to;
if (v == fa)continue;
num[u] += dfs2(v, u);
}
return num[u];
}
int main()
{
freopen("test.txt", "r", stdin);
int n;
while (scanf("%d", &n) != EOF)
{
int u, v;
memset(head, -1, sizeof(head));
tot = 0;
for (int i = 0; i < n - 1; i++)
{
scanf("%d%d", &u, &v);
addedge(u, v);
addedge(v, u);
}
dfs(1, -1);//建树
memset(num, 0, sizeof(num));
deep[1] = 0;
dfs2(1, -1);//统计每个结点的结点数
for (int j = 1; j < 30;j++)
for (int i = 1; i <= n; i++)
p[i][j] = p[p[i][j - 1]][j - 1];//i的第j层祖先等于i的第j-1层祖先的祖先
int m;
scanf("%d", &m);
while (m--)
{
scanf("%d%d", &u, &v);
if (u == v)
{
printf("%d\n", n);
continue;
}
int LCA = lca(u, v);
int d1 = deep[u] - deep[LCA];
int d2 = deep[v] - deep[LCA];
if (d1 != d2)
{
if (abs(d1 - d2) & 1)//距离差是奇数,输出0
printf("0\n");
else
{
if (deep[u] < deep[v])
swap(u, v);
int dist = (d1 + d2) / 2;
int uu = u;
for (int k = 0; k < 30;k++)
if (dist&(1 << k))
uu = p[uu][k];
dist--;
int vv = u;
for (int k = 0; k < 30;k++)
if (dist&(1 << k))
vv = p[vv][k];
printf("%d\n", num[uu] - num[vv]);
}
}
else
{
int uu = u;
int dist = d1 - 1;
for (int k = 0; k < 30;k++)
if (dist&(1 << k))
uu = p[uu][k];
int vv = v;
for (int k = 0; k < 30;k++)
if (dist&(1 << k))
vv = p[vv][k];
printf("%d\n", n - num[vv] - num[uu]);
}
}
}
return 0;
}