题目链接:http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1405
给定一棵无根树,假设它有n个节点,节点编号从1到n, 求任意两点之间的距离(最短路径)之和。
Input
第一行包含一个正整数n (n <= 100000),表示节点个数。 后面(n - 1)行,每行两个整数表示树的边。
Output
每行一个整数,第i(i = 1,2,...n)行表示所有节点到第i个点的距离之和。
Input示例
4 1 2 3 2 4 2
Output示例
5 3 5 5
题解:
数学题,也可以说dp,不太难。
(1)我们给树规定一个根。假设所有节点编号是0-(n-1),我们可以简单地把0当作根,这样下来父子关系就确定了。
(2)定义数组num[x]表示以节点x为根的子树有多少个节点,dp[x]是我们所求的——所有节点到节点x的距离之和。
(3)在步骤(1)中,其实我们同时可以计算出 num[x],还可以计算出每个节点的深度(每个到根节点0的距离),累加全部节点深度得到的其实就是是dp[0]。
(4) 假设一个非根节点x,它的父亲节点是y, 并且dp[y]已经计算好了,我们如何计算dp[x]?
以x为根的子树中那些节点,到x的距离比到y的距离少1, 这样的节点有num[x]个。
其余节点到x的距离比到y的距离多1,这样的节点有(n - num[x])个。
于是我们有 dp[x] = dp[y] - num[x] + (n - num[x])
= dp[y] + n - num[x] * 2
因为树的根节点dp[0]在步骤(3)已经计算出来了,根据所有的父子关系和这个上式,我们可以按照顺序计算出整个dp数组。
注意点: 重要的步骤都是简单的dfs,但是一半递归实现可能导致堆栈溢出。
自己补充: 非根节点也满足上述递归公式。
num[x]是包含本身的子树节点数目
一开始写得很挫,不过AC了。
<span style="font-size:24px;">#include"stdio.h"
#include"stdlib.h"
#include"algorithm"
#include"string.h"
#include"vector"
using namespace std;
#pragma comment(linker, "/STACK:10240000,10240000")
//对于一些因为递归太深,导致爆栈的程序,可以使用扩栈语句,但仅限VC编译
const int maxn=1e5+5;
vector<int>G[maxn];
int num[maxn]; //num[i] 统计i的孩子数
long long dp[maxn]; //dp[i]:所有节点到i的最短距离之和
int n;
void ReadTree()
{
scanf("%d",&n);
for(int i=0;i<n-1;i++)
{
int s,e;
scanf("%d%d",&s,&e);
G[s].push_back(e);
G[e].push_back(s);
}
}
void dfs(int root,int fa)
{//求dp[root]
num[root]=1;
dp[root]=0;
int d=G[root].size();
for(int i=0;i<d;i++)
{
int k=G[root][i];
if(k!=fa)
{
dfs(k,root);
num[root]+=num[k];
dp[root]+=num[k]-1+dp[k]+1;
}
}
}
void Print()
{
for(int i=1;i<=n;i++)
printf("%lld\n",dp[i]);
}
void Solve_dfs(int root,int fa)
{
int d=G[root].size();
for(int i=0;i<d;i++)
{
int k=G[root][i];
if(k!=fa)
{
if(num[k]==0) //叶子
dp[k]=dp[root]+n-2;
else
dp[k]=dp[root]+n-2*num[k];
Solve_dfs(k,root);
}
}
//printf("%d*******\n",root);
//Print();
}
int main()
{
ReadTree();
dfs(1,-1);
//printf("%d *\n",dp[1]);
//for(int i=1;i<=n;i++)
// printf("%d ",num[i]);
// puts("\n");
Solve_dfs(1,-1);
Print();
return 0;
}</span>
看了别人代码后改进:
<span style="font-size:24px;">#include"stdio.h"
#include"string.h"
#include"vector"
using namespace std;
#pragma comment(linker, "/STACK:10240000,10240000")
//对于一些因为递归太深,导致爆栈的程序,可以使用扩栈语句,但仅限VC编译
const int maxn=1e5+5;
vector<int>G[maxn];
int num[maxn]; //num[i] 统计i的子树节点
long long dp[maxn]; //dp[i]:所有节点到i的最短距离之和
int n;
void Print()
{
for(int i=1;i<=n;i++)
printf("%lld\n",dp[i]);
}
void ReadTree()
{
scanf("%d",&n);
for(int i=0;i<n-1;i++)
{
int s,e;
scanf("%d%d",&s,&e);
G[s].push_back(e);
G[e].push_back(s);
}
}
void dfs(int root,int fa,int floor)
{//求dp[1]
num[root]=1;
dp[1]+=floor;
int d=G[root].size();
for(int i=0;i<d;i++)
{
int k=G[root][i];
if(k!=fa)
{
dfs(k,root,floor+1);
num[root]+=num[k];
}
}
}
void Solve_dfs(int root,int fa)
{
int d=G[root].size();
for(int i=0;i<d;i++)
{
int k=G[root][i];
if(k!=fa)
{
dp[k]=dp[root]+n-2*num[k];
Solve_dfs(k,root);
}
}
}
int main()
{
ReadTree();
dfs(1,-1,0);
Solve_dfs(1,-1);
Print();
return 0;
} </span>
别人代码,感觉写得很好(完全跟着题解思路步骤)值得学习。
<span style="font-size:24px;">#pragma comment(linker, "/STACK:102400000,102400000")
#include <stdio.h>
#include <stdlib.h>
#include <cstring>
#include <algorithm>
#include <vector>
#include <map>
#include <queue>
#include <set>
#include <functional>
using namespace std;
#define PINF 100000000
#define NINF -100000000
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define FOR(i, f, e) for(int i = f; i < e; i++)
typedef __int64 ll;
#define maxn 100000
vector<int> G[maxn];//连边
int subTree[maxn];
ll dp[maxn];//子节点个数和每个节点的距离和
int N;//节点总个数
bool vis[maxn];
int dfs1(int u, int k){//返回点u的子节点个数(包括u自己),顺便计算dp[0]
subTree[u] = vis[u] = 1;
dp[0] += (ll)k;
for (unsigned i = 0; i < G[u].size(); i++){
if (!vis[G[u][i]])
subTree[u] += dfs1(G[u][i], k + 1);
}
return subTree[u];
}
void dfs2(int u, int fa){//计算dp
vis[u] = 1;
if(u != 0) dp[u] = dp[fa] - subTree[u] * 2 + N;
for (unsigned i = 0; i < G[u].size(); i++){
if (!vis[G[u][i]]) dfs2(G[u][i], u);
}
}
int main(){
scanf("%d", &N);
FOR(i, 1, N){
int a, b;
scanf("%d%d", &a, &b);
a--; b--;
G[a].push_back(b);
G[b].push_back(a);
}
dfs1(0, 0);
memset(vis, false, sizeof(vis));
dfs2(0, 0);
FOR(i, 0, N){
printf("%lld\n", dp[i]);
}
}</span>