非常经典的一道题,第二次做又调了半天。
做法:tarjan缩点+最短路+差分约束
首先看到题第一眼就知道是个差分约束题,然后题中求
x
>
=
1
x>=1
x>=1 的最小解,就可以想到最长路了。先对题中给出的限制条件进行转化。
X
=
1
X=1
X=1 则连边
a
d
d
(
a
,
b
,
0
)
,
a
d
d
(
b
,
a
,
0
)
add(a,b,0),add(b,a,0)
add(a,b,0),add(b,a,0)。
X
=
2
X=2
X=2 则连边
a
d
d
(
a
,
b
,
1
)
add(a,b,1)
add(a,b,1)。
X
=
3
X=3
X=3 则连边
a
d
d
(
b
,
a
,
0
)
add(b,a,0)
add(b,a,0)。
X
=
4
X=4
X=4 则连边
a
d
d
(
b
,
a
,
1
)
add(b,a,1)
add(b,a,1)。
X
=
5
X=5
X=5 则连边
a
d
d
(
a
,
b
,
0
)
add(a,b,0)
add(a,b,0)。
然后建立超级源点
0
0
0,对每个点连一条边权为
1
1
1 的边,再跑一遍tarjan,缩完点后在DAG上求最长路。这些就是比较套路的操作了,相信大家都会做了。
参考代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=2e5+10;
int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){
if(ch=='-') f=-1;
ch=getchar();
}
while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return x*f;
}
int n,m,head[N],tot,dist[N],he[N],tot2;
int dfn[N],low[N],tim,scc[N],cnt,stk[N],top,sz[N],vis[N];
struct node{
int to,nxt,w;
}edge[N*2],e[N*2];
void add(int x,int y,int w){
edge[++tot].to=y;
edge[tot].w=w;
edge[tot].nxt=head[x];
head[x]=tot;
}
void add2(int x,int y,int w){
e[++tot2].to=y;
e[tot2].w=w;
e[tot2].nxt=he[x];
he[x]=tot2;
}
void tarjan(int x){
dfn[x]=low[x]=++tim,stk[++top]=x,vis[x]=1;
for(int i=head[x];i;i=edge[i].nxt){
int y=edge[i].to;
if(!dfn[y]){
tarjan(y);
low[x]=min(low[x],low[y]);
}
else if(vis[y]) low[x]=min(low[x],dfn[y]);
}
if(dfn[x]==low[x]){
int y;cnt++;
do{
y=stk[top--],vis[y]=0,scc[y]=cnt,sz[cnt]++;
}while(x!=y);
}
}
int main(){
n=read(),m=read();
for(int i=1;i<=m;i++){
int opt=read(),a=read(),b=read();
if(opt==1) add(a,b,0),add(b,a,0);
else if(opt==2) add(a,b,1);
else if(opt==3) add(b,a,0);
else if(opt==4) add(b,a,1);
else if(opt==5) add(a,b,0);
}
for(int i=1;i<=n;i++) add(0,i,1);
tarjan(0);
for(int x=0;x<=n;x++){//x记得从0开始
for(int i=head[x];i;i=edge[i].nxt){
int y=edge[i].to;
if(scc[x]!=scc[y]) add2(scc[x],scc[y],edge[i].w);
else if(edge[i].w>0){puts("-1");return 0;}
}
}
for(int x=cnt;x>=1;x--){
for(int i=he[x];i;i=e[i].nxt){
int y=e[i].to;
if(dist[y]<dist[x]+e[i].w) dist[y]=dist[x]+e[i].w;
}
}
ll ans=0;
for(int i=1;i<=cnt;i++) ans+=(ll)sz[i]*dist[i];
cout<<ans<<endl;
return 0;
}