题解:点分治枚举任意两点为端点的路,判断路是否是正确的,计算重心为根的树,暴力跑子树,看子树是否有从根到子树上的点的路有错误,标记这个点,标记这颗子树,如果有多颗子树被标记,那么这整颗树都标记上,如果只有一颗子树被标记,那么除这颗子树之外的子树全部标记,然后在判断与根相连的边有没有相同的,如果相同,这两条边相连两颗子树的点都要被标记,最后没有被标记的点就是正确答案。
代码:
#include <bits/stdc++.h>
using namespace std;
const int maxn=2e5+4;
const int inf=0x3f3f3f3f;
int head[maxn],tot;
struct node{
int v,w,next;
}s[maxn*2];
int siz[maxn],maxp[maxn],vis[maxn];
int sum,root,n;
void addre(int u,int v,int w){
s[tot].v=v; s[tot].w=w; s[tot].next=head[u]; head[u]=tot++;
}
void getroot(int u,int fa){
siz[u]=1; maxp[u]=0;
for(int i=head[u];i!=-1;i=s[i].next)
{
int v=s[i].v;
if(v==fa||vis[v]) continue;
getroot(v,u);
siz[u]+=siz[v];
maxp[u]=max(maxp[u],siz[v]);
}
maxp[u]=max(maxp[u],sum-siz[u]);
if(maxp[u]<maxp[root]) root=u;
}
int flag;
int pxx[maxn]; int anss[maxn];
void getdis(int u,int fa,int last,int ti){
for(int i=head[u];i!=-1;i=s[i].next){
int v=s[i].v,w=s[i].w;
if(v==fa||vis[v]) continue;
if(w==last||ti) {
flag=1; anss[v]=1;
getdis(v,u,w,ti|1);
}
else getdis(v,u,w,ti);
}
}
void dfs(int u,int fa){ //标记子树上的点
anss[u]=1;
for(int i=head[u];i!=-1;i=s[i].next){
int v=s[i].v;
if(v==fa||vis[v]) continue;
dfs(v,u);
}
}
void calc(int u){ // 计算
int ax=0,tag;
for(int i=head[u];i!=-1;i=s[i].next){ // 标记子树
int v=s[i].v,w=s[i].w;
if(vis[v]) continue;
flag=0;
getdis(v,u,w,0);
if(flag) ax++,tag=v;
}
if(ax) anss[u]=1;
if(ax>1) { //如果被标记的子树有两个那么整颗树都被标记上
for(int i=head[u];i!=-1;i=s[i].next){
int v=s[i].v,w=s[i].w;
if(vis[v]) continue;
dfs(v,u);
}
}else if(ax==1){ // 如果只有一个那么其他的子树全部被标记上
for(int i=head[u];i!=-1;i=s[i].next){
int v=s[i].v,w=s[i].w;
if(vis[v]) continue;
if(tag==v) continue;
dfs(v,u);
}
}
for(int i=head[u];i!=-1;i=s[i].next){ //如果两条边都相等,那么这两颗子树都要被标记
int v=s[i].v,w=s[i].w;
if(vis[v]) continue;
if(pxx[w]==0) pxx[w]=v;
else {
dfs(v,u); dfs(pxx[w],u); pxx[w]=v;
}
}
for(int i=head[u];i!=-1;i=s[i].next){
int v=s[i].v,w=s[i].w;
if(vis[v]) continue;
pxx[w]=0;
}
}
void solve(int u){
vis[u]=1;
calc(u);
for(int i=head[u];i!=-1;i=s[i].next)
{
int v=s[i].v;
if(vis[v]) continue;
sum=siz[v]; maxp[root=0]=inf;
getroot(v,0); solve(root);
}
}
int main()
{
scanf("%d",&n);
memset(head,-1,sizeof(head)); tot=0;
for(int i=1;i<=n-1;i++){
int u,v,w;
scanf("%d %d %d",&u,&v,&w);
addre(u,v,w);
addre(v,u,w);
}
maxp[root=0]=sum=n;
getroot(1,0);
solve(root);
int cntx=0;
for(int i=1;i<=n;i++){
if(anss[i]==0) cntx++;
}
cout<<cntx<<endl;
for(int i=1;i<=n;i++){
if(anss[i]==0) printf("%d\n",i);
}
return 0;
}