总算是自己写过了这道所谓的中等题。
原题:Garland
题意:给定一个n个点的树以及每个点的价值。问能否去掉两条边使得剩下的三个部分价值和相等,若可以输出要切掉的那条边所连接的子树根节点, 不可以则输出-1。
首先,如果温度总和不能除以3就输出-1。
对于答案,有两种可能,第一种是在母体切出两个sum/3的子体,第二种是切出一个sum/3*2的子体,再在这个子体切出sum/3的子体
这里使用链式前向星建树,建树时用fa数组保存父结点。
然后用dfs计算以每个点i**为root时整棵子树的温度总和**,记作treetem[i],如果总和是sum/3,就用set把这个结点记录下来。
这里不能直接确定第一种情况,因为温度可能为负为零,所以一个treetem为sum/3的结点可能是另一个treetem为sum/3的结点的父结点,就和预料中的第一种分法冲突了。
所以要先看第二种分法,对于每个ans_1_3(treetem为sum/3的结点),用之前记录的fa数组回溯,如果有个ans_2_3(treetem为sum/3*2的结点),就是符合第二种分法的答案了,这个时候可以顺便对第一种方法进行排雷,删掉那些本文上一段描述的点。
如果每个点往上都要回溯到root,时间复杂度太高,所以这里可以用vis数组来优化。
如果第二种分法找不到答案,再考虑第一种。此时的set内已经没有了雷,所以只要set有两个以上结点,就可以随便找两个当作答案输出了。
代码如下:
#include<iostream>
#include<cstdio>
#include<cmath>
#include<string>
#include<cstring>
#include<algorithm>
#include<set>
#include<map>
#include<list>
#include<vector>
#include<stack>
#include<queue>
#include<functional>
#define D long long
#define F double
#define MAX 0x7fffffff
#define MIN -0x7fffffff
#define mmm(a,b) memset(a,b,sizeof(a))
#define pb push_back
#define mk make_pair
#define fi first
#define se second
#define pill pair<int, int>
using namespace std;
#define N 1001000
#define MOD ((int)1e9+7)
set<int>ans_1_3;//sum/3的点
set<int>::iterator it,itt;
int n;int head[N];
D sum;
struct edge{
int to,nex;
}e[4*N];int now=0;
void add(int a,int b){
e[++now].to=b;e[now].nex=head[a];head[a]=now;//1~now
}
int tem[N];//灯的温度
int fa[N];//父灯
int root;//根
int treetem[N];//生出子树的总温度
int vis[N];//对ans_1_3的回溯时的访问标记
void dfs(int root){
treetem[root]+=tem[root];
for(int i=head[root];~i;i=e[i].nex){
dfs(e[i].to);
treetem[root]+=treetem[e[i].to];
}
if(treetem[root]==sum/3&&fa[root]!=0)ans_1_3.insert(root);//第一次少了 fa[root]!=0 就WA在了案例33
//因为sum可能是0,这时少了这句就把root结点也加进去了
}
int main()
{sum=0;mmm(head,-1);mmm(tem,0);mmm(treetem,0);mmm(vis,0);
cin>>n;
for(int i=1;i<=n;i++){//第一行输入2,则2指向1,fa[1]=2;
scanf("%d%d",&fa[i],&tem[i]);sum+=tem[i];
add(fa[i],i);
if(fa[i]==0)root=i;
}
if(sum%3!=0){cout<<-1<<endl;return 0;}
dfs(root);
for(it=ans_1_3.begin();it!=ans_1_3.end();it++){
int mid=*it;
while(fa[mid]!=root){
mid=fa[mid];if(vis[mid])break;vis[mid]=1;
if(treetem[mid]==sum/3*2){
cout<<*it<<' '<<mid<<endl;return 0;
}
if(treetem[mid]==sum/3){
ans_1_3.erase(mid);
}
}
}
if(ans_1_3.size()<2){cout<<-1<<endl;return 0;}
itt=it=ans_1_3.begin();itt++;
cout<<*it<<' '<<*itt<<endl;
return 0;
}