原理
我们知道,并查集可以维护两个点是否是在同一个组内,那么带权并查集不仅有并查集的特征,还有任意集合内任意两点之间的关系,其中最为典型的是一维数轴相对距离模型。
一维数轴相对距离模型
a距离b为100,代表a在b的右边100个单位距离,c距离d为-20,代表c在d的左边20个单位的距离,现在,我们并不知道a与c,a与d点之间的距离关系。根据以上给出的距离关系,我们可以得知:dist[a]=100,dist[b]=0,dist[c]=-20,dist[d]=0,其中dist[xx]代表某个值距离当前组头节点的距离值,正数代表在头节点右边,0代表该节点为改组的头节点,负数代表在头节点的左边。现在给出a距离c为-60,那么就可以得出四个点之间的关系了,如下图:
现在a,b,c,d为同一组的点,如何去维护正确的关系呢?例如:b和d的实际距离是多少?可以用-20-100+(-60) 得到实际的值:-180,实际上确实如此。现在a的距离100为错的,怎么办呢?当进行寻找头节点是进行路径压缩是会自动修正。当b变成-180时,由于a的头节点是b,当进行寻找头部的过程中,会进行100+(-180) == -80,为什么?因为找头节点的过程就是dist进行相加。那么上面的-20-100+(-60)可以抽象为dist[后]-dist[前]+v.
- 点权代表当前节点到组头节点的距离,并不保证实时正确,可以经过find过程修正正确
- void union(l,r,v),l,r属于两个集合,l和r的距离为v,合并两个集合,find(l)的头为lf,find(r)的头为rf,find的过程会修正dist[r]和dist[l]
- find(i):寻找i所在组的头,同时修正dist[i]的值,路径压缩之前i的头为t,路径压缩后,dist[i]+=dist[t]
- quert(l,r)查询l和r的关系,find(l)==find(r),即在同一组,才有距离关系,距离=dist[l]-dist[r].
题目
模板题
可以经过如下的转化:将线段的距离转换为点的距离,就和上面的相同了
c++代码如下:
const int MAXN = 100002;
const long long INF = 9223372036854775807LL;
int n, m, q;
int father[MAXN];
long long dist[MAXN];
void prepare() {
for (int i = 0; i <= n; ++i) {
father[i] = i;
dist[i] = 0;
}
}
int find(int i) {
if (i != father[i]) {
int tmp = father[i];
father[i] = find(tmp);
dist[i] += dist[tmp];
}
return father[i];
}
void myUnion(int l, int r, long long v) {
int lf = find(l), rf = find(r);
if (lf != rf) {
father[lf] = rf;
dist[lf] = v + dist[r] - dist[l];
}
}
long long query(int l, int r) {
if (find(l) != find(r)) {
return INF;
}
return dist[l] - dist[r];
}
int main() {
scanf("%d", &n);
n += 1;
scanf("%d", &m);
scanf("%d", &q);
prepare();
int l, r;
long long v;
for (int i = 1; i <= m; ++i) {
scanf("%d", &l);
scanf("%d", &r);
r += 1;
scanf("%lld", &v);
myUnion(l, r, v);
}
for (int i = 1; i <= q; ++i) {
scanf("%d", &l);
scanf("%d", &r);
r += 1;
v = query(l, r);
if (v == INF) {
printf("UNKNOWN\n");
}
else {
printf("%lld\n", v);
}
}
return 0;
}
这里的find要重点看一下:如果要进行节点的修正,首先记录下老的头节点,在进行更新头节点,最后进行值的更新。
模板题
与题目1不同的是需要判断错误的数据,方法是如果新来的一组数据不是一个组的,进行合并,如果新来的数据在同一个组而且这两个点的距离不同,那么这个数据就是错误的数据。
代码如下:
const int MAXN = 102;
int t, n, m;
bool ans;
int father[MAXN];
int dist[MAXN];
void prepare() {
ans = true;
for (int i = 1; i <= n; i++) {
father[i] = i;
dist[i] = 0;
}
}
int find(int i) {
if (i != father[i]) {
int tmp = father[i];
father[i] = find(tmp);
dist[i] += dist[tmp];
}
return father[i];
}
void myUnion(int l, int r, int v) {
int lf = find(l), rf = find(r);
if (lf != rf) {
father[lf] = rf;
dist[lf] = v + dist[r] - dist[l];
}
}
bool check(int l, int r, int v) {
if (find(l) == find(r)) {
if ((dist[l] - dist[r] != v)) {
return false;
}
}
return true;
}
int main() {
scanf("%d", &t);
for (int c = 1; c <= t; ++c) {
scanf("%d", &n);
n += 1;
scanf("%d", &m);
prepare();
ans = true;
for (int i = 1, l, r, v; i <= m; ++i) {
scanf("%d", &l);
scanf("%d", &r);
r += 1;
scanf("%lld", &v);
if (!check(l, r, v)) {
ans = false;
}
else {
myUnion(l, r, v);
}
}
printf("%s\n", ans ? "true" : "false");
}
return 0;
}
维护距离关系
https://www.luogu.com.cn/problem/P1196
其中M l r:合并l队伍和r队伍,将l队伍整体移到r后面,如果是同一个队伍,不进行任何操作,C l r:如果l和r不在一个队伍,打印-1,如果是,打印它们相隔几艘战舰。
对与这个题来说,我们可以将dist距离改为某艘战舰前面有几艘战舰,当进行合并时,l队伍的头节点(数值为0,头节点前面没有任何战舰)加上r队伍的战舰数量,l队伍的其他节点可以在find的过程进行修正,对与查询过程:a和b中间相隔战舰数=dist[a]-dist[b]的绝对值-1.代码如下:
#include <cstdio>
#include <cstdlib>
const int MAXN = 30001;
int n = 30000;
int father[MAXN];
int dist[MAXN];
int size[MAXN];
int stack[MAXN];
void prepare() {
for (int i = 1; i <= n; ++i) {
father[i] = i;
dist[i] = 0;
size[i] = 1;
}
}
int find(int i) {
int si = 0;
while (i != father[i]) {
stack[++si] = i;
i = father[i];
}
stack[si + 1] = i;
for (int j = si; j >= 1; --j) {
father[stack[j]] = i;
dist[stack[j]] += dist[stack[j + 1]];
}
return i;
}
void myUnion(int l, int r) {
int lf = find(l), rf = find(r);
if (lf != rf) {
father[lf] = rf;
dist[lf] += size[rf];
size[rf] += size[lf];
}
}
int query(int l, int r) {
if (find(l) != find(r)) {
return -1;
}
return abs(dist[l] - dist[r]) - 1;
}
int main() {
prepare();
int t;
scanf("%d", &t);
char op[2];
for (int i = 1; i <= t; ++i) {
scanf("%s", op);
int l, r;
scanf("%d %d", &l, &r);
if (op[0] == 'M') {
myUnion(l, r);
}
else {
printf("%d\n", query(l, r));
}
}
return 0;
}
维护倍数关系
将模板题的加减变为乘除,即dist[后]/dist[前]*v,find的过程dist相关计算也变为相乘。
class Solution {
public:
vector<double> calcEquation(vector<vector<string>>& equations, vector<double>& values, vector<vector<string>>& queries) {
prepare(equations);
for (int i = 0; i < values.size(); ++i) {
union_(equations[i][0], equations[i][1], values[i]);
}
vector<double> ans(queries.size());
for (int i = 0; i < queries.size(); ++i) {
ans[i] = query(queries[i][0], queries[i][1]);
}
return ans;
}
private:
unordered_map<string, string> father;
unordered_map<string, double> dist;
void prepare(vector<vector<string>>& equations) {
father.clear();
dist.clear();
for (auto& list : equations) {
for (auto& key : list) {
father[key] = key;
dist[key] = 1.0;
}
}
}
string find(string x) {
if (father.find(x) == father.end()) {
return "";
}
string tmp, fa = x;
if (x != father[x]) {
tmp = father[x];
fa = find(tmp);
dist[x] *= dist[tmp];
father[x] = fa;
}
return fa;
}
void union_(string l, string r, double v) {
string lf = find(l), rf = find(r);
if (lf != rf) {
father[lf] = rf;
dist[lf] = dist[r] / dist[l] * v;
}
}
double query(string l, string r) {
string lf = find(l), rf = find(r);
if (lf.empty() || rf.empty() || lf != rf) {
return -1.0;
}
return dist[l] / dist[r];
}
};
维护种类关系
[NOI2001] 食物链 - 洛谷https://www.luogu.com.cn/problem/P2024
对与后两种情况很容易进行判断,重要的时第一种情况假话的判断,如何判断当前的话和之前的真话有冲突。
dist[a]=0代表a和头节点时同类,dist[b]=1代表b吃头节点,dist[c]=2代表c被头节点吃:
由于可能是负数,式子修正为(dist[后]-dist[前]+v+3)%3,怎么来的?由于两者之间的关系是在0,1,2进行切换的,因此可能有%3的这个关系,由于前面有这样的关系,因此猜测。
如果a和b不是在一个集合中,先进行合并,然后根据关系判断是否为同类。如果a和b在一个集合并且对与头节点有相同的关系,那么一定是同类。
判断a是否吃b,如果不在一个集合中 ,先根据关系进行合并。本段()以内数字都是相对于头的,情况1,a吃头,b和头是同类(a==1,b==0);情况2,a和头是同类,头吃b(a==0,b==2);情况3:a被头吃,b吃头(a==2,b==1),可以总结为(dist[a]-dist[b]+3)%3==1,a吃b是真话
代码如下:
#include <cstdio>
#define MAXN 50001
int n, k, ans;
int father[MAXN];
int dist[MAXN];
void prepare() {
ans = 0;
for (int i = 1; i <= n; ++i) {
father[i] = i;
dist[i] = 0;
}
}
int find(int i) {
if (i != father[i]) {
int tmp = father[i];
father[i] = find(tmp);
dist[i] = (dist[i] + dist[tmp]) % 3;
}
return father[i];
}
void union_(int op, int l, int r) {
int lf = find(l), rf = find(r), v = op == 1 ? 0 : 1;
if (lf != rf) {
father[lf] = rf;
dist[lf] = (dist[r] - dist[l] + v + 3) % 3;
}
}
// op == 1, 1 l r,l和r是同类
// op == 2, 2 l r,l吃r
bool check(int op, int l, int r) {
if (l > n || r > n || (op == 2 && l == r)) {
return false;
}
if (find(l) == find(r)) {
if (op == 1) {
if (dist[l] != dist[r]) {
return false;
}
}
else {
if ((dist[l] - dist[r] + 3) % 3 != 1) {
return false;
}
}
}
return true;
}
int main() {
FILE* fp = stdin;
fscanf(fp, "%d%d", &n, &k);
prepare();
int op, l, r;
for (int i = 1; i <= k; ++i) {
fscanf(fp, "%d%d%d", &op, &l, &r);
if (!check(op, l, r)) {
ans++;
}
else {
union_(op, l, r);
}
}
printf("%d\n", ans);
return 0;
}