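// 题意: 给定 n 个点、m 条边的有向带权图, 求从 a 到 b 的最短路中, 两两不共用边的最短路最多有多少条。
// (Problem: directed weighted graph with n vertices and m edges; find the maximum number of edge-disjoint shortest paths from a to b.)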
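// 题解: 在原图上从 a 跑一遍最短路得到 d1[], 在反向图上从 b 跑一遍最短路得到 d2[], a 到 b 的最短距离为 D = d1[b]。枚举每条边 (u, v, w), 若 d1[u] + w + d2[v] == D, 则该边在某条最短路上, 把它加入网络流图, 容量为 1。超级源点 S 连 a, b 连超级汇点 T (容量均为无穷大), 用 Dinic 跑一遍最大流即为答案。
// (Solution: shortest distances from a (d1) and, on the reversed graph, to b (d2); keep every edge with d1[u] + w + d2[v] == D as a capacity-1 arc; connect super source S -> a and b -> super sink T with infinite capacity; the Dinic max flow is the answer.)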
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>
#include <vector>
using namespace std;
// Capacity limits for the static arrays below.
const int MaxN = 20000;    // upper bound on vertex ids
const int MaxM = 2222222;  // upper bound on the number of edges
struct node
{
int v, w, next;
}e[MaxM];
// Arc of the residual flow network kept in the adjacency vectors f[].
//   v   - head vertex of the arc
//   rev - index of the paired reverse arc inside f[v]
//   cap - remaining residual capacity
struct Node
{
    int v, rev;
    int cap;
    // Members are always initialized in declaration order (v, rev, cap), so the
    // initializer list is written in that order too (avoids -Wreorder).
    Node(int v, int cap, int rev) : v(v), rev(rev), cap(cap) {}
};
// Global state shared by the shortest-path and max-flow phases.
// Vertex-indexed arrays have MaxN + 2 slots: the flow network uses vertex 0 as
// super source and n + 1 as super sink, so index n + 1 must be valid even when
// n == MaxN.
int n, m, cnt;                                   // vertex count, edge count, next free slot in e[]
int g[MaxN + 2], a[MaxN + 2];                    // g: adjacency-list heads (-1 = empty); NOTE(review): a[] appears unused
int dist[MaxN + 2], d1[MaxN + 2], d2[MaxN + 2];  // dist: scratch for spfa/bfs; d1: from s; d2: to t
int u[MaxM + 1], v[MaxM + 1], w[MaxM + 1];       // input edges, 1-indexed
vector<Node> f[MaxN + 2];                        // residual-network adjacency lists
int cur[MaxN + 2];                               // Dinic current-arc index per vertex
void addedge(int u, int v, int w)
{
e[cnt].v = v;
e[cnt].w = w;
e[cnt].next = g[u];
g[u] = cnt++;
}
// Adds arc u -> v with capacity `cap` to the residual network, together with
// its zero-capacity reverse arc v -> u. Each arc records in `rev` the index of
// its partner so dfs() can update both in O(1).
void add(int u, int v, int cap)
{
    // Index the reverse arc will occupy in f[v]. For a self-loop (u == v) the
    // forward arc is pushed into that same list first, shifting it by one;
    // without this correction the forward arc's rev would point at itself.
    const int revIdx = static_cast<int>(f[v].size()) + (u == v ? 1 : 0);
    f[u].push_back(Node(v, cap, revIdx));
    f[v].push_back(Node(u, 0, static_cast<int>(f[u].size()) - 1));
}
void spfa(int s)
{
memset(dist, 0x3f, sizeof(dist));
dist[s] = 0;
bool vis[MaxN + 1] = {0};
vis[s] = true;
queue<int> q;
q.push(s);
while (!q.empty())
{
int t = q.front(); q.pop();
vis[t] = false;
for (int i = g[t]; i != -1; i = e[i].next)
{
int v = e[i].v;
if (dist[v] > dist[t] + e[i].w)
{
dist[v] = dist[t] + e[i].w;
if (!vis[v])
{
vis[v] = true;
q.push(v);
}
}
}
}
}
bool bfs(int s, int t)
{
memset(dist, -1, sizeof(dist));
queue<int> q;
dist[s] = 0;
q.push(s);
while (!q.empty())
{
int x = q.front(); q.pop();
if (x == t)
return true;
for (int i = 0; i < f[x].size(); i++)
{
Node &e = f[x][i];
if (e.cap > 0 && dist[e.v] < 0)
{
dist[e.v] = dist[x] + 1;
q.push(e.v);
}
}
}
return false;
}
int dfs(int u, int v, int flow)
{
if (u == v)
return flow;
for (int &i = cur[u]; i < f[u].size(); i++)
{
Node &e = f[u][i];
int d;
if (e.cap > 0 && dist[u] + 1 == dist[e.v] && (d = dfs(e.v, v, min(flow, e.cap))) > 0)
{
e.cap -= d;
f[e.v][e.rev].cap += d;
return d;
}
}
return 0;
}
// Dinic max flow from s to t. Each phase rebuilds the level graph with bfs(),
// resets the current-arc pointers, and then saturates it with repeated dfs()
// calls until no augmenting path remains.
int maxflow(int s, int t)
{
    int total = 0;
    while (bfs(s, t))
    {
        memset(cur, 0, sizeof(cur));
        for (int pushed = dfs(s, t, 0x3f3f3f3f); pushed > 0; pushed = dfs(s, t, 0x3f3f3f3f))
            total += pushed;
    }
    return total;
}
// Reads the number of test cases. Each case gives "n m", then m directed edges
// "u v w", then the query endpoints "s t". For each case it prints the maximum
// number of edge-disjoint shortest s -> t paths (0 when t is unreachable).
int main()
{
    const int INF = 0x3f3f3f3f;  // same value spfa() leaves in unreached dist[] slots
    int k = 0;
    if (scanf("%d", &k) != 1)
        return 0;
    while (k--)
    {
        if (scanf("%d %d", &n, &m) != 2)
            return 0;
        // Forward graph -> d1[] = distance from s. Self-loops can never lie on
        // a shortest path, so they are dropped here and in the flow graph below.
        memset(g, -1, sizeof(g));
        cnt = 0;
        for (int i = 1; i <= m; i++)
        {
            if (scanf("%d %d %d", &u[i], &v[i], &w[i]) != 3)
                return 0;
            if (u[i] == v[i]) continue;
            addedge(u[i], v[i], w[i]);
        }
        int s, t;
        if (scanf("%d %d", &s, &t) != 2)
            return 0;
        spfa(s);
        const int ans = dist[t];
        // t unreachable: no shortest path exists. Bail out before INF sentinels
        // are summed below, where d1 + w + d2 could overflow a signed int.
        if (ans >= INF)
        {
            printf("0\n");
            continue;
        }
        for (int i = 1; i <= n; i++)
            d1[i] = dist[i];
        // Reversed graph: spfa(t) on it yields d2[x] = distance from x to t.
        memset(g, -1, sizeof(g));
        cnt = 0;
        for (int i = 1; i <= m; i++)
            if (v[i] != u[i])
                addedge(v[i], u[i], w[i]);
        spfa(t);
        for (int i = 1; i <= n; i++)
            d2[i] = dist[i];
        // Flow network over vertices 0 (super source S) .. n + 1 (super sink T).
        for (int i = 0; i <= n + 1; i++)
            f[i].clear();
        // Edge (u, v, w) lies on some shortest path iff d1[u] + w + d2[v] == ans.
        // Summed in long long because d1/d2 may still hold the INF sentinel.
        // Capacity 1 => every edge is used by at most one path.
        for (int i = 1; i <= m; i++)
            if (u[i] != v[i] && 1LL * d1[u[i]] + w[i] + d2[v[i]] == ans)
                add(u[i], v[i], 1);
        const int S = 0, T = n + 1;
        add(S, s, INF);
        add(t, T, INF);
        // NOTE(review): if s == t the S->s->T arcs alone carry INF units of flow,
        // so INF is printed; confirm whether inputs can have s == t.
        printf("%d\n", maxflow(S, T));
    }
    return 0;
}