题目描述
传送门
题目大意:给出n个数,每个数有两权a,b。
两个数能同时选,必须满足下面至少一个条件
(1) $\gcd(a_i,a_j)\neq 1$
(2)不存在整数 $T$ 满足 $a_i^2+a_j^2=T^2$
求所选集合 $\sum b$ 的最大值
题解
设源汇分别为S,T,对于每个数拆成两个点xi,yi
S->xi 容量为bi
yi->T 容量为bi
xi->yj 容量为inf,当且仅当 i,j 两个数不满足上述任何一个条件(即两者不能同时选)。
跑最大流求最小割。由于每一对冲突关系都被对称地连了两条边(xi->yj 和 xj->yi),最小割把每个被舍弃的数计算了两次,所以答案为 ∑b − 最大流/2。
代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<queue>
#include<cmath>
// Capacity of every per-node and per-edge array declared below.
#define N 4000003
// Large sentinel: capacity given to the conflict edges built in main and the
// starting bottleneck value in addflow.
#define inf 1000000000
// Shorthand for long long (not referenced in the code shown in this file).
#define LL long long
using namespace std;
// Adjacency-list storage for the flow network:
//   remain[e] - residual capacity of edge e
//   point[x]  - most recently added edge leaving node x (-1 when none)
//   nxt[e]    - next edge in the same node's list (-1 terminates the list)
//   v[e]      - node that edge e points to
// n is the number of items read in main, tot the index of the last edge
// created by add, and S / T the source and sink node ids.
int remain[N],point[N],nxt[N],v[N],n,tot,S,T;
// Per-node state for the max-flow routine:
//   deep[x] - distance label of node x
//   last[x] - edge through which x was reached on the current path
//   num[d]  - how many nodes currently carry label d
//   cur[x]  - edge at which the next scan of x's list resumes
// a[i] and b[i] are the two weights of item i (1-based), filled in by main.
int deep[N],last[N],num[N],cur[N],a[N],b[N];
// Inserts the directed edge x -> y with capacity z and, in the very next
// slot, its reverse edge y -> x with capacity 0. Both edges are pushed onto
// the front of their tail node's adjacency list.
void add(int x,int y,int z)
{
    // Forward edge x -> y.
    const int forward = ++tot;
    v[forward] = y;
    remain[forward] = z;
    nxt[forward] = point[x];
    point[x] = forward;

    // Reverse edge y -> x; it starts empty and only gains capacity when flow
    // is pushed along the forward edge.
    const int backward = ++tot;
    v[backward] = x;
    remain[backward] = 0;
    nxt[backward] = point[y];
    point[y] = backward;
}
// Returns x squared. Single-argument helper; it is unrelated to the
// two-argument floating-point pow from <cmath>.
long long pow(long long x){
    const long long squared = x * x;
    return squared;
}
// Greatest common divisor of x and y via Euclid's algorithm.
// gcd(x, 0) is x; each step replaces (x, y) by (y, x mod y).
int gcd(int x,int y)
{
    if (y == 0) return x;
    return gcd(y, x % y);
}
// Tells whether items i and j may be selected together.
// Returns true when a[i] and a[j] share a common factor greater than 1, or
// when a[i]^2 + a[j]^2 is not a perfect square; returns false only when the
// two values are coprime AND their squares sum to a perfect square.
bool check(int i,int j)
{
    const bool shareFactor = gcd(a[i],a[j]) != 1;
    if (shareFactor) return true;

    const long long sumSquares = pow(a[i]) + pow(a[j]);
    // Truncated floating-point root; squaring it back detects non-squares.
    const long long root = sqrt(sumSquares);
    return root*root != sumSquares;
}
// Pushes flow along the path recorded in last[] from s to t.
// The path is walked backwards from t: last[x] is the edge that enters x and
// v[last[x]^1] is that edge's tail. Returns the amount pushed, which is the
// smallest residual capacity on the path.
int addflow(int s,int t)
{
    // Pass 1: find the bottleneck capacity.
    int bottleneck = inf;
    for (int node = t; node != s; node = v[last[node]^1]) {
        const int edge = last[node];
        bottleneck = min(bottleneck, remain[edge]);
    }

    // Pass 2: subtract it from every path edge and credit the paired
    // reverse edge.
    for (int node = t; node != s; ) {
        const int edge = last[node];
        remain[edge] -= bottleneck;
        remain[edge^1] += bottleneck;
        node = v[edge^1];
    }
    return bottleneck;
}
void bfs(int s,int t)
{
for (int i=s;i<=t;i++) deep[i]=t;
deep[t]=0;
queue<int> p; p.push(t);
while (!p.empty()) {
int now=p.front(); p.pop();
for (int i=point[now];i!=-1;i=nxt[i])
if (deep[v[i]]==t&&remain[i^1])
deep[v[i]]=deep[now]+1,p.push(v[i]);
}
}
// Computes the total flow that can be pushed from s to t using distance
// labels (deep[]), a per-node resume pointer (cur[]) and a per-label node
// count (num[]). Nodes are taken to be numbered 1..t.
// Returns the sum of all amounts pushed by addflow.
int isap(int s,int t)
{
// Initial labels come from a backward BFS started at t.
int ans=0; bfs(s,t);
// Reset every node's scan position and count how many nodes hold each label.
for (int i=1;i<=t;i++) cur[i]=point[i],num[deep[i]]++;
int now=s;
// Stop once the source's label reaches t.
while (deep[s]<t) {
// Reached the sink: push flow along the recorded path and restart at s.
if (now==t) {
ans+=addflow(s,t);
now=s;
}
// Advance: look for an edge with residual capacity whose head is labelled
// exactly one less than the current node, resuming from cur[now].
bool pd=false;
for (int i=cur[now];i!=-1;i=nxt[i])
if (deep[now]==deep[v[i]]+1&&remain[i]) {
pd=true;
// Remember how v[i] was reached and where to resume scanning now.
last[v[i]]=i; cur[now]=i;
now=v[i]; break;
}
// Retreat: no usable edge, so relabel the current node.
if (!pd) {
// Smallest label among heads of edges with residual capacity (capped at t).
int minn=t;
for (int i=point[now];i!=-1;i=nxt[i])
if (remain[i]) minn=min(deep[v[i]],minn);
// If no node is left with the old label, stop the whole search.
if (!--num[deep[now]]) break;
// Assign the new label (minn+1) and count the node under it.
num[deep[now]=minn+1]++;
// The label changed, so rescan this node's edge list from its head.
cur[now]=point[now];
// Step back to the tail of the edge that led here (unless at the source).
if (now!=s) now=v[last[now]^1];
}
}
return ans;
}
// Reads n, then the n values a[1..n], then the n values b[1..n]; builds the
// flow network and prints (sum of b) minus half of the flow returned by isap.
// Node layout: S = 1, item i's first copy is node i+1, its second copy is
// node i+n+1, and T = 2n+2.
int main()
{
    freopen("a.in","r",stdin);
    freopen("my.out","w",stdout);

    scanf("%d",&n);

    // Empty graph: no edges yet and every adjacency list terminated by -1.
    tot = -1;
    memset(point,-1,sizeof(point));
    S = 1;
    T = 2*n + 2;

    for (int idx = 1; idx <= n; ++idx) scanf("%d",&a[idx]);

    int total = 0;
    for (int idx = 1; idx <= n; ++idx) {
        scanf("%d",&b[idx]);
        total += b[idx];
        const int firstCopy = idx + 1;
        const int secondCopy = idx + n + 1;
        add(S, firstCopy, b[idx]);
        add(secondCopy, T, b[idx]);
    }

    // Every ordered pair that may not be chosen together gets an edge of
    // capacity inf from i's first copy to j's second copy.
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= n; ++j) {
            if (check(i,j)) continue;
            add(i + 1, j + n + 1, inf);
        }
    }

    const int flow = isap(S,T);
    printf("%d\n", total - flow/2);
}