// Solution (题解):
#include <bits/stdc++.h>
// forn(i, n): loop an int counter `i` over the half-open range [0, n).
// The bound expression `n` is re-evaluated and converted to int on every check.
#define forn(i, n) for (int i = 0; i < static_cast<int>(n); ++i)
using namespace std;
// Global state shared by dfs() and the rest of the solution.
int n{};                              // NOTE(review): presumably the vertex count; confirm against the input-reading code
std::vector<int> a;                   // not read in the visible code; presumably per-vertex values (TODO confirm)
std::vector<std::vector<int>> g;      // adjacency lists: dfs() walks g[v], skipping the parent p
long long ans{};                      // result accumulator; 64-bit wide (presumably to avoid int overflow, confirm)
std::vector<std::map<int, int>> cnt;  // one map per vertex; dfs() keeps the child whose map is largest
void dfs(int v, int p = -1){
int bst = -1;
for (int u : g[v]) if (u != p){
dfs(u, v);
if (bst == -1 || cnt[bst].size() < cnt[u].size())
bst = u;
}
for (int u : g[v]) if (u != p && u !