演艺
题目背景:
分析:最短路 + 拓扑排序 + bitset + 哈希
感觉是本场最难的题······考虑如何为满足条件的A,B,就是在S à T的最短路DAG上,经过A的S à T的方案数与经过B的S à T之和是S à T的总方案数,并且A不能到B。考虑如何实现,首先我们以S为起点跑dijkstra,获得每一个点的dis,如果满足dis[e->u] + e->w = dis[e->v],那么这条边e就在最短路DAG上,我们可以由此进行拓扑排序,然后根据拓扑序,分别求出从S出发,只经过最短路径图上的边,到点i的方案数x,和从T出发,只经过最短路径图上的边的反向边,到点i的方案数y,那么最终的经过i的从S到T的最短路径数位x * y,因为数值过大,所以需要hash一下,这样之后,对于每一个点i,可能的可以成对的点的hash值是固定的,那么我们只需要放到一个map / hash_map中统计一下出现次数就可以了(代码使用了双哈希,所以用map方便一些),显然,这样的答案算进了某些不合法方案,因为并没有判断是否符合A不能到B这个条件,首先,我们可以将hash值标号,将每一个hash值对应的节点有哪些,将状态压到一个bitset里面,同样,利用拓扑序和bitset,求得每一个点在S到它的最短路径上会经过的点,然后将每一个方案数的hash值为点i对应的可行点的hash值的点集,与可以到达点i的点的点集的重合部分减去即可,注意这一步只能针对在最短路DAG上的点。
Source:
/*
created by scarlyw
*/
#include <cstdio>
#include <string>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <cmath>
#include <cctype>
#include <vector>
#include <set>
#include <queue>
#include <ctime>
#include <map>
#include <bitset>
inline char read() {
static const int IN_LEN = 1024 * 1024;
static char buf[IN_LEN], *s, *t;
if (s == t) {
t = (s = buf) + fread(buf, 1, IN_LEN, stdin);
if (s == t) return -1;
}
return *s++;
}
/*
template<class T>
inline void R(T &x) {
static char c;
static bool iosig;
for (c = read(), iosig = false; !isdigit(c); c = read()) {
if (c == -1) return ;
if (c == '-') iosig = true;
}
for (x = 0; isdigit(c); c = read())
x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
const int OUT_LEN = 1024 * 1024;
char obuf[OUT_LEN], *oh = obuf;
inline void write_char(char c) {
if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;
*oh++ = c;
}
template<class T>
inline void W(T x) {
static int buf[30], cnt;
if (x == 0) write_char('0');
else {
if (x < 0) write_char('-'), x = -x;
for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;
while (cnt) write_char(buf[cnt--]);
}
}
inline void flush() {
fwrite(obuf, 1, oh - obuf, stdout);
}
///*
template<class T>
inline void R(T &x) {
static char c;
static bool iosig;
for (c = getchar(), iosig = false; !isdigit(c); c = getchar())
if (c == '-') iosig = true;
for (x = 0; isdigit(c); c = getchar())
x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
const int MAXN = 50000 + 10;
const long long INF = 1000000000000000000;
int n, m, s, t, x, y, z, cnt;
long long dis[MAXN];
struct node {
int to, w;
node(int to = 0, int w = 0) : to(to), w(w) {}
inline bool operator < (const node &a) const {
return w > a.w;
}
} ;
std::vector<node> edge[MAXN];
inline void add_edge(int x, int y, int z) {
edge[x].push_back(node(y, z)), edge[y].push_back(node(x, z));
}
inline void read_in() {
R(n), R(m), R(s), R(t);
for (int i = 1; i <= m; ++i) R(x), R(y), R(z), add_edge(x, y, z);
}
inline void dijkstra(int s) {
static bool vis[MAXN];
std::priority_queue<node> q;
for (int i = 1; i <= n; ++i) dis[i] = INF;
dis[s] = 0, q.push(node(s, 0));
while (!q.empty()) {
while (!q.empty() && vis[q.top().to]) q.pop();
if (q.empty()) break ;
int cur = q.top().to;
vis[cur] = true, q.pop();
for (int p = 0; p < edge[cur].size(); ++p) {
node *e = &edge[cur][p];
if (!vis[e->to] && dis[e->to] > dis[cur] + e->w)
dis[e->to] = dis[cur] + e->w, q.push(node(e->to, dis[e->to]));
}
}
}
bool vis[MAXN];
std::vector<int> top;
inline void dfs(int cur) {
vis[cur] = true;
for (int p = 0; p < edge[cur].size(); ++p) {
node *e = &edge[cur][p];
if (!vis[e->to] && dis[e->to] == dis[cur] + e->w) dfs(e->to);
}
top.push_back(cur);
}
int d[MAXN];
std::pair<long long, long long> hash[MAXN], hash1[MAXN], hash2[MAXN], need;
const int mod1 = 1000000000 + 7, mod2 = 1000000000 + 9;
inline void solve_hash() {
for (int cur = 1; cur <= n; ++cur)
for (int p = 0; p < edge[cur].size(); ++p) {
node *e = &edge[cur][p];
if (dis[e->to] == dis[cur] + e->w)
d[e->to]++;
}
for (int i = 1; i <= n; ++i) if (d[i] == 0) dfs(i);
for (int i = 1; i <= n; ++i)
hash[i] = hash1[i] = hash2[i] = std::make_pair(0, 0);
hash1[s] = std::make_pair(1, 1);
for (int i = top.size() - 1; i >= 0; --i) {
int cur = top[i];
for (int p = 0; p < edge[cur].size(); ++p) {
node *e = &edge[cur][p];
if (dis[e->to] == dis[cur] + e->w) {
hash1[e->to].first = (hash1[e->to].first +
hash1[cur].first) % mod1;
hash1[e->to].second = (hash1[e->to].second +
hash1[cur].second) % mod2;
}
}
}
hash2[t] = std::make_pair(1, 1);
for (int i = 0; i < top.size(); ++i) {
int cur = top[i];
for (int p = 0; p < edge[cur].size(); ++p) {
node *e = &edge[cur][p];
if (dis[e->to] == dis[cur] + e->w) {
hash2[cur].first = (hash2[cur].first +
hash2[e->to].first) % mod1;
hash2[cur].second = (hash2[cur].second +
hash2[e->to].second) % mod2;
}
}
}
for (int i = 1; i <= n; ++i) {
hash[i].first = hash1[i].first * hash2[i].first % mod1;
hash[i].second = hash1[i].second * hash2[i].second % mod2;
}
}
std::map<std::pair<long long, long long>, int> mp, id;
std::bitset<MAXN> reach[MAXN], able[MAXN], bit;
inline void solve_ans() {
long long ans = 0;
for (int i = 1; i <= n; ++i) {
need.first = (hash[t].first - hash[i].first + mod1) % mod1;
need.second = (hash[t].second - hash[i].second + mod2) % mod2;
ans += mp[need], mp[hash[i]]++;
}
for (int i = 1; i <= n; ++i) {
if (id.find(hash[i]) == id.end()) id[hash[i]] = cnt++;
able[id[hash[i]]][i] = 1;
}
for (int i = top.size() - 1; i >= 0; --i) {
int cur = top[i];
if (cur == s || cur == t) continue ;
for (int p = 0; p < edge[cur].size(); ++p) {
node *e = &edge[cur][p];
if (dis[e->to] == dis[cur] + e->w)
reach[e->to] |= reach[cur], reach[e->to][cur] = 1;
}
}
for (int cur = 1; cur <= n; ++cur) {
if (hash[cur].first == 0 && hash[cur].second == 0) continue ;
need.first = (hash[t].first - hash[cur].first + mod1) % mod1;
need.second = (hash[t].second - hash[cur].second + mod2) % mod2;
bit = (able[id[need]] & reach[cur]), ans -= bit.count();
}
std::cout << ans;
}
int main() {
freopen("b.in", "r", stdin);
freopen("b.out", "w", stdout);
read_in();
dijkstra(s);
// dfs(s);
solve_hash();
solve_ans();
return 0;
}