题目大意
n ( ≤ 2000 ) n(\le 2000) n(≤2000) 个点 m ( ≤ 5000 ) m(\le 5000) m(≤5000) 条有向边的图,用Floyd算法求最短路,如果把松弛点的循环放在内层,问有多少组答案依然正确。
思路
可以先
O
(
n
m
log
n
)
O(nm\log n)
O(nmlogn) 求所有点对间的最短路。
考虑错误的 Floyd 求
s
s
s 到
t
t
t 的最短路。能得到正解当且仅当有一个有效的松弛点
v
v
v 使得
v
v
v 在最短路上且
s
s
s 到
v
v
v,
v
v
v 到
t
t
t 均得到正解。
固定
u
u
u ,可以用 bitset 维护到
v
v
v 的最短路会经过哪些点。
另外用 bitset 维护每个点作为起点和终点到达的哪些点是正解。
代码
#include <bits/stdc++.h>
#define rep(i, l, r) for (int i = l; i <= r; ++i)
using namespace std;
const int N = 2005;
const int M = 5005;
const int inf = 0x3fffffff;
int n, m;
int hd[N];
class edge {
public:
int to, val, nxt;
edge() {}
} e[M];
void add(int u, int v, int w, int i) {
e[i].to = v;
e[i].val = w;
e[i].nxt = hd[u];
hd[u] = i;
}
int dis[N][N];
bitset<N> from[N], to[N], pot[N];
class node {
public:
int num, dis;
node() {}
node(int _num, int _dis) : num(_num), dis(_dis) {}
bool operator<(const node &rhs) const { return dis > rhs.dis; }
};
void dijkstra(int uu) {
priority_queue<node> q;
rep(i, 1, n) { dis[uu][i] = inf; }
q.push(node(uu, dis[uu][uu] = 0));
while (!q.empty()) {
node now = q.top();
q.pop();
int u = now.num;
if (dis[uu][u] != now.dis) continue;
// printf("%d", u);
for (int i = hd[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (dis[uu][u] + e[i].val < dis[uu][v]) {
q.push(node(v, dis[uu][v] = dis[uu][u] + e[i].val));
}
}
}
for (int i = hd[uu]; i; i = e[i].nxt) {
int v = e[i].to;
if (e[i].val == dis[uu][v]) {
from[uu][v] = to[v][uu] = 1;
}
}
}
void explore(int uu) {
static int f[N], g[N];
rep(i, 1, n) {
pot[i].reset();
pot[i][i] = 1;
}
rep(i, 1, n) f[i] = i, g[i] = dis[uu][i];
sort(f + 1, f + n + 1, [](int i, int j) { return g[i] < g[j]; });
rep(ui, 1, n) {
int u = f[ui];
for (int i = hd[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (g[u] + e[i].val == g[v]) {
pot[v] |= pot[u];
}
}
}
rep(i, 1, n) {
if (g[i] == inf || (pot[i] & to[i] & from[uu]).any()) {
from[uu][i] = to[i][uu] = 1;
}
}
}
int main() {
scanf("%d%d", &n, &m);
rep(i, 1, n) {
hd[i] = 0;
from[i].reset();
to[i].reset();
from[i][i] = to[i][i] = 1;
}
rep(i, 1, m) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
add(u, v, w, i);
// from[u][v] = to[v][u] = 1;
}
rep(u, 1, n) { dijkstra(u); }
/*
rep(i, 1, n) {
rep(j, 1, n) { printf("%d ", dis[i][j]); }
printf("\n");
}
*/
rep(u, 1, n) { explore(u); }
int ans = 0;
rep(u, 1, n) { ans += from[u].count(); }
printf("%d", ans);
return 0;
}