坑了一上午QAQ感觉自己好弱智啊
BZOJ上的题面坑死人,2N是闹哪样啊,明明是2^N,害的我还以为是水题,WA了好几次。
然后上COGS(好评)上看了下题,发现是2^N,然后论文里的省空间方法好麻烦,于是直接用vector+动态开节点水过去了。
话说我这个写得怎么这么像线段树2333333
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
using namespace std;
const int inf=1e9;
struct Node{
int lc,rc,pa,sum[1100];
vector<int>f[1100];
}tr[2200];
int n,f[1100][1100],tot;
int bin[20];
void build(int &o,int l,int r,int depth){
o=++tot;
memset(tr[o].sum,0,sizeof(tr[o].sum));
for(int i=0;i<=r-l+1;i++)
tr[o].f[i].resize(bin[depth]);
if(l==r){
for(int i=1;i<=n;i++)
tr[o].sum[i]+=f[i][l];
return;
}
int mid=l+r>>1;
build(tr[o].lc,l,mid,depth+1);build(tr[o].rc,mid+1,r,depth+1);
for(int i=1;i<=n;i++)
tr[o].sum[i]=tr[tr[o].lc].sum[i]+tr[tr[o].rc].sum[i];
tr[tr[o].lc].pa=tr[tr[o].rc].pa=o;
}
int calclayer(int x){
int ans=0;
while(x)x=tr[x].pa,ans++;
return ans;
}
int org[1100],change[1100];
void dp(int o,int l,int r){
int k=calclayer(o);
if(l==r){
for(int j=0;j<=1;j++)
for(int i=0;i<bin[k-1];i++){
int t=!j;
tr[o].f[j][i]+=(t!=org[l])*change[l];
int x=k-1,y=o;
while(x){
int c=(tr[tr[y].pa].lc)==y;
tr[o].f[j][i]+=(((i&bin[x-1])?1:0)==t)*tr[c?(tr[tr[y].pa].rc):(tr[tr[y].pa].lc)].sum[l];
x--;y=tr[y].pa;
}
}
}else{
int mid=l+r>>1;
dp(tr[o].lc,l,mid);dp(tr[o].rc,mid+1,r);
int len=r-l+1;
for(int j=0;j<=len;j++){
int s=(j>=len-j)*bin[k-1];
for(int i=0;i<bin[k-1];i++){
tr[o].f[j][i]=inf;
for(int u=0;u<=j;u++){
if(u>len-(len>>1)||j-u>(len>>1))continue;
tr[o].f[j][i]=min(tr[o].f[j][i],tr[tr[o].lc].f[u][i|s]+tr[tr[o].rc].f[j-u][i|s]);
//printf("%d %d %d\n",u,i|s,tr[lc].f[u][i|s]+tr[rc].f[j-u][i|s]);
}
//printf("%d %d %d %d %d\n",l,r,j,i,tr[o].f[j][i]);
}
}
}
}
int main(){
//freopen("networkcost.in","r",stdin);
//freopen("networkcost.out","w",stdout);
bin[0]=1;
for(int i=1;i<20;i++)bin[i]=bin[i-1]<<1;
scanf("%d",&n);n=1<<n;
for(int i=1;i<=n;i++)scanf("%d",&org[i]);
for(int i=1;i<=n;i++)scanf("%d",&change[i]);
for(int i=1;i<n;i++){
for(int j=1;i+j<=n;j++)
scanf("%d",&f[i][i+j]);
for(int j=i+1;j<=n;j++)
f[j][i]=f[i][j];
}
int root;
build(root,1,n,0);
dp(root,1,n);
int ans=inf;
for(int i=0;i<=n;i++)
ans=min(ans,tr[root].f[i][0]);
printf("%d\n",ans);
return 0;
}