/*
 * Reset elements 1..n to singleton sets.
 * Each node becomes its own parent, at distance 0 from it, in a set of size 1.
 * Index 0 and indices above n are not touched.
 */
void init(int n) {
    for (int node = n; node >= 1; node--) {
        size[node] = 1;
        dist[node] = 0;
        fa[node] = node;
    }
}
/*
 * Return the root of x's set, compressing the path from x.
 *
 * dist[v] is the distance from v to fa[v]. When the call returns, every
 * node that was on the path from x to the root has fa[] pointing directly
 * at the root, and dist[] holds its total distance to the root.
 *
 * Why this is iterative: merge() always attaches a's root beneath b's root,
 * whatever the set sizes (dist[] depends on that direction). Trees can
 * therefore degenerate into a chain of length n, and a recursive find
 * would need that much call-stack depth.
 *
 * This relies on dist[root] == 0 for every root. init() sets that, and
 * merge() only writes dist[] of the node that stops being a root.
 */
int get(int x) {
    /* Pass 1: locate the root and sum the edge weights from x up to it.
     * long long keeps the running sum at least as wide as dist[]'s element
     * type (assumed to be an integer type -- confirm against its declaration). */
    long long total = 0;
    int root = x;
    while (fa[root] != root) {
        total += dist[root];
        root = fa[root];
    }

    /* Pass 2: walk the same path again. Give each node its full distance
     * to the root, then drop that node's own edge from the remaining sum. */
    int v = x;
    while (v != root) {
        int next = fa[v];
        long long edge = dist[v];
        dist[v] = total;
        total -= edge;
        fa[v] = root;
        v = next;
    }
    return root;
}
/*
 * Join the sets that contain a and b. Argument order matters.
 *
 * a's root is hung directly beneath b's root, at distance size[b's root].
 * As a result, every member of a's set ends up farther from the new root
 * than that offset. b's root then absorbs a's member count.
 * If a and b already share a root, nothing changes.
 */
void merge(int a, int b) {
    int root_a = get(a);
    int root_b = get(b);
    if (root_a == root_b) {
        return;
    }
    fa[root_a] = root_b;
    dist[root_a] = size[root_b];
    size[root_b] += size[root_a];
}