最大流模板【Luogu3376】(dinic)
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
typedef long long ll;
const int maxn = 10005;
const int maxm = 200005;
const ll inf = 0x3f3f3f3f3f3f3f3f;
int n, m, src, snk, dis[maxn];
int tot, ter[maxm], nxt[maxm], lnk[maxn];
ll wei[maxm];
// Appends one directed edge u -> v with capacity w to the edge arrays.
//
// The edge is stored at index `tot` and pushed onto the front of u's
// singly linked edge list (lnk[u] is the list head, nxt[] the links).
// The capacity is taken as long long so that it matches the element type
// of wei[] and is not narrowed when the caller holds a 64-bit value.
void adde(int u, int v, long long w) {
    ter[tot] = v;        // edge endpoint
    wei[tot] = w;        // remaining capacity of this edge
    nxt[tot] = lnk[u];   // previous head becomes the successor
    lnk[u] = tot++;      // new edge becomes the head of u's list
}
bool bfs(int s, int t) {
queue<int> que;
que.push(s);
memset(dis, -1, sizeof(dis));
dis[s] = 0;
ll w;
for (int u, v; !que.empty(); ) {
u = que.front();
que.pop();
for (int i = lnk[u]; ~i; i = nxt[i]) {
v = ter[i], w = wei[i];
if (w && dis[v] == -1) {
dis[v] = dis[u] + 1;
que.push(v);
}
}
}
return ~dis[t];
}
// Depth-first augmentation along the level graph produced by bfs().
//
// u   : current vertex.
// lft : upper bound on the amount that may still be pushed out of u.
// Returns the amount actually pushed from u towards snk; wei[] of every
// used edge is decreased and wei[] of its partner edge (index ^ 1) is
// increased by the same amount.
ll find(int u, ll lft) {
    if (u == snk) {
        return lft;
    }

    ll pushed = 0;
    for (int e = lnk[u]; e != -1 && pushed < lft; e = nxt[e]) {
        const int to = ter[e];
        const ll cap = wei[e];
        // Only follow non-exhausted edges that go exactly one level deeper.
        if (cap == 0 || dis[to] != dis[u] + 1) {
            continue;
        }
        const ll got = find(to, min(cap, lft - pushed));
        if (got == 0) {
            continue;
        }
        wei[e] -= got;
        wei[e ^ 1] += got;
        pushed += got;
    }

    // u could not absorb the full request: remove it from the level graph
    // so later searches in this phase do not enter it again.
    if (pushed < lft) {
        dis[u] = -1;
    }
    return pushed;
}
// Runs the Dinic phases from src to snk and returns the accumulated flow.
//
// Each phase rebuilds the level labels with bfs() and then calls find()
// repeatedly until it reports that nothing more could be pushed.
ll dinic() {
    ll total = 0;
    while (bfs(src, snk)) {
        for (ll pushed = find(src, inf); pushed != 0; pushed = find(src, inf)) {
            total += pushed;
        }
    }
    return total;
}
// Reads "n m src snk" followed by m lines "u v w", builds the graph with a
// zero-capacity reverse edge after every input edge, and prints the value
// returned by dinic().
//
// Returns 1 (with a message on stderr) if the input cannot be parsed or a
// vertex id would index outside the per-vertex arrays; returns 0 otherwise.
int main() {
    memset(lnk, -1, sizeof(lnk));  // -1 marks an empty edge list

    // Number of slots in the per-vertex arrays; vertex ids are used as
    // direct indices, so they must lie in [0, vertex_cap).
    const int vertex_cap = static_cast<int>(sizeof(lnk) / sizeof(lnk[0]));

    if (scanf("%d %d %d %d", &n, &m, &src, &snk) != 4) {
        fprintf(stderr, "malformed header line\n");
        return 1;
    }
    if (src < 0 || src >= vertex_cap || snk < 0 || snk >= vertex_cap) {
        fprintf(stderr, "source or sink out of range\n");
        return 1;
    }

    for (int i = 1; i <= m; i++) {
        int u, v;
        ll w;
        if (scanf("%d %d %lld", &u, &v, &w) != 3) {
            fprintf(stderr, "malformed edge line %d\n", i);
            return 1;
        }
        if (u < 0 || u >= vertex_cap || v < 0 || v >= vertex_cap) {
            fprintf(stderr, "edge %d has a vertex out of range\n", i);
            return 1;
        }
        // Forward edge and its reverse are added back to back so that
        // their indices differ only in the lowest bit.
        adde(u, v, w), adde(v, u, 0);
    }

    printf("%lld\n", dinic());
    return 0;
}
费用流模板【Luogu3381】(多路增广)
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
typedef long long ll;
const int maxn = 5005;
const int maxm = 100005;
const ll inf = 0x3f3f3f3f3f3f3f3f;
bool vis[maxn];
int n, m, src, snk;
int tot, ter[maxm], nxt[maxm], lnk[maxn];
ll ans, len[maxm], wei[maxm], dis[maxn];
// Appends one directed edge u -> v to the edge arrays.
//
// w0 is stored in len[] (the quantity find() decrements when it pushes
// flow, i.e. the capacity) and w1 in wei[] (the per-unit cost that spfa()
// sums along a path). Both are taken as long long so that they match the
// element type of len[]/wei[] and are not narrowed when the caller holds
// 64-bit values.
void adde(int u, int v, long long w0, long long w1) {
    ter[tot] = v;        // edge endpoint
    len[tot] = w0;       // remaining capacity
    wei[tot] = w1;       // cost per unit of flow
    nxt[tot] = lnk[u];   // previous head becomes the successor
    lnk[u] = tot++;      // new edge becomes the head of u's list
}
bool spfa(int s, int t) {
queue<int> que;
que.push(s);
memset(dis, 0x3f, sizeof(dis));
dis[s] = 0, vis[s] = 1;
ll w0, w1;
for (int u, v; !que.empty(); ) {
u = que.front();
que.pop();
vis[u] = 0;
for (int i = lnk[u]; ~i; i = nxt[i]) {
v = ter[i], w0 = len[i], w1 = wei[i];
if (w0 && dis[v] > dis[u] + w1) {
dis[v] = dis[u] + w1;
if (!vis[v]) {
vis[v] = 1;
que.push(v);
}
}
}
}
return dis[t] < inf;
}
// Depth-first augmentation restricted to edges that lie on a shortest
// path according to the dis[] labels computed by spfa().
//
// u   : current vertex.
// lft : upper bound on the amount that may still be pushed out of u.
// Returns the amount pushed from u towards snk. For every used edge,
// len[] is decreased, len[] of the partner edge (index ^ 1) is increased,
// and amount * wei[edge] is added to the global `ans`.
ll find(int u, ll lft) {
    if (u == snk) {
        return lft;
    }

    vis[u] = true;  // on the current DFS stack; keeps the search off cycles
    ll pushed = 0;
    for (int e = lnk[u]; e != -1 && pushed < lft; e = nxt[e]) {
        const int to = ter[e];
        const ll cap = len[e];
        const ll cost = wei[e];
        const bool usable = cap != 0 && !vis[to] && dis[u] + cost == dis[to];
        if (!usable) {
            continue;
        }
        const ll got = find(to, min(lft - pushed, cap));
        if (got == 0) {
            continue;
        }
        ans += got * cost;
        len[e] -= got;
        len[e ^ 1] += got;
        pushed += got;
    }
    vis[u] = false;
    return pushed;
}
// Alternates spfa() and find() until snk is no longer reachable and
// returns the total amount of flow pushed. The accumulated cost is left
// in the global `ans` by find().
ll mcmf() {
    ll total = 0;
    while (spfa(src, snk)) {
        for (ll pushed = find(src, inf); pushed != 0; pushed = find(src, inf)) {
            total += pushed;
        }
    }
    return total;
}
// Reads "n m src snk" followed by m lines "u v w0 w1", builds the graph
// with a zero-capacity, negated-cost reverse edge after every input edge,
// and prints the value returned by mcmf() followed by the global `ans`.
//
// Returns 1 (with a message on stderr) if the input cannot be parsed or a
// vertex id would index outside the per-vertex arrays; returns 0 otherwise.
int main() {
    memset(lnk, -1, sizeof(lnk));  // -1 marks an empty edge list

    // Number of slots in the per-vertex arrays; vertex ids are used as
    // direct indices, so they must lie in [0, vertex_cap).
    const int vertex_cap = static_cast<int>(sizeof(lnk) / sizeof(lnk[0]));

    if (scanf("%d %d %d %d", &n, &m, &src, &snk) != 4) {
        fprintf(stderr, "malformed header line\n");
        return 1;
    }
    if (src < 0 || src >= vertex_cap || snk < 0 || snk >= vertex_cap) {
        fprintf(stderr, "source or sink out of range\n");
        return 1;
    }

    for (int i = 1; i <= m; i++) {
        int u, v;
        ll w0, w1;
        if (scanf("%d %d %lld %lld", &u, &v, &w0, &w1) != 4) {
            fprintf(stderr, "malformed edge line %d\n", i);
            return 1;
        }
        if (u < 0 || u >= vertex_cap || v < 0 || v >= vertex_cap) {
            fprintf(stderr, "edge %d has a vertex out of range\n", i);
            return 1;
        }
        // Forward edge and its reverse are added back to back so that
        // their indices differ only in the lowest bit.
        adde(u, v, w0, w1), adde(v, u, 0, -w1);
    }

    printf("%lld ", mcmf());
    printf("%lld\n", ans);
    return 0;
}