Description
有N个村庄坐落在一条直线上,第i(i>1)个村庄距离第1个村庄的距离为Di。需要在这些村庄中建立不超过K个通讯基站,在第i个村庄建立基站的费用为Ci。如果在距离第i个村庄不超过Si的范围内建立了一个通讯基站,那么就成它被覆盖了。如果第i个村庄没有被覆盖,则需要向他们补偿,费用为Wi。现在的问题是,选择基站的位置,使得总费用最小。 输入数据 (base.in) 输入文件的第一行包含两个整数N,K,含义如上所述。 第二行包含N-1个整数,分别表示D2,D3,…,DN ,这N-1个数是递增的。 第三行包含N个整数,表示C1,C2,…CN。 第四行包含N个整数,表示S1,S2,…,SN。 第五行包含N个整数,表示W1,W2,…,WN。
Input
输出文件中仅包含一个整数,表示最小的总费用。
Output
3 2 1 2 2 3 2 1 1 0 10 20 30
Sample Input
4
Sample Output
40%的数据中,N<=500;
100%的数据中,K<=N,K<=100,N<=20,000,Di<=1000000000,Ci<=10000,Si<=1000000000,Wi<=10000。
题解:
设f[i][k]表示到第i个村庄,第i个村庄一定会建基站,已经建了k个基站的最小费用.
f[i][k]=min{f[j][k-1]+solve(j+1,i-1)}+c[i];
solve(x,y)表示x到y这一段的最小补偿费用.这个dp是O(n^3)的,显然tle.
主要的瓶颈在于solve(x,y)的计算.考虑y增加对solve(x,y)的影响.
原来被左端点覆盖的没有影响,被右端点覆盖的会减少并且不会再被覆盖.
考虑用线段树优化这个dp.用线段树维护f[x]+solve(x+1,y)的最小值.
设l[x],r[x]表示在[l[x]+1,r[x]-1]范围内建造基站村庄x能被覆盖.
每次处理完x位置,枚举所有r[a]=x的村庄a,然后在线段树中把[1,l[a]]都加上w[a].
因为a村庄已经无法被右端点覆盖,所以这些当做左端点也无法覆盖a的村庄的f+solve值肯定要增加w[a];
为了实现这一过程,可以用链表来维护r[x]相同的村庄.
枚举k,每次都要重新建树.我们可以把n和k加1,这样每次最优值都储存在f[n]中.
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#define N 200010
#define K 110
#define inf 1000000000
using namespace std;
int n,k,st[N],en[N],next[N],point[N],d[N],c[N],s[N],w[N];
long long t[N<<2],p[N<<2],f[N];
vector<int>q[N];
int read(){
int x(0);char ch=getchar();
while (ch<'0'||ch>'9') ch=getchar();
while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
return x;
}
int find(int x){
int l=1,r=n,ans;
if (x<=0) return 0;
while (l<=r){
int mid=(l+r)>>1;
if (d[mid]<x){ans=mid;l=mid+1;}
else r=mid-1;
}
return ans;
}
void build(int k,int l,int r){
int mid=(l+r)>>1;
if (l==r){t[k]=f[l];return;}
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
t[k]=min(t[k<<1],t[k<<1|1]);
}
void paint(int k,int l,int r,long long v){t[k]+=v;p[k]+=v;}
void pushdown(int k,int l,int r,long long v){
int mid=(l+r)>>1;
paint(k<<1,l,mid,v);
paint(k<<1|1,mid+1,r,v);
p[k]=0;
}
void add(int k,int l,int r,int ll,int rr,long long v){
if (ll>rr) return;
int mid=(l+r)>>1;
if (ll<=l&&r<=rr) {paint(k,l,r,v);return;}
if (p[k]) pushdown(k,l,r,p[k]);
if (ll<=mid) add(k<<1,l,mid,ll,rr,v);
if (mid<rr) add(k<<1|1,mid+1,r,ll,rr,v);
t[k]=min(t[k<<1],t[k<<1|1]);
}
int query(int k,int l,int r,int ll,int rr){
if (ll>rr) return 0;
int mid=(l+r)>>1;
if (l==ll&&r==rr) return t[k];
if (p[k]) pushdown(k,l,r,p[k]);
if (rr<=mid) return query(k<<1,l,mid,ll,rr);
else if (mid<ll) return query(k<<1|1,mid+1,r,ll,rr);
else return min(query(k<<1,l,mid,ll,mid),query(k<<1|1,mid+1,r,mid+1,rr));
}
void solve(){
long long temp(0);
long long ans=inf;
for (int i=1;i<=n;i++){
f[i]=(long long)(temp+c[i]);
for (int o=point[i];o;o=next[o])
temp+=w[o];
}
ans=f[n];
for (int i=2;i<=k;i++){
build(1,1,n);memset(p,0,sizeof(p));
for (int j=1;j<=n;j++){
f[j]=query(1,1,n,1,j-1)+c[j];
for (int o=point[j];o;o=next[o])
add(1,1,n,1,st[o],w[o]);
}
ans=min(ans,f[n]);
}
cout<<ans<<endl;
}
int main(){
n=read();k=read();
for (int i=2;i<=n;i++) d[i]=read();
for (int i=1;i<=n;i++) c[i]=read();
for (int i=1;i<=n;i++) s[i]=read();
for (int i=1;i<=n;i++) w[i]=read();
n=n+1;k=k+1;d[n]=inf;w[n]=inf;
for (int i=1;i<=n;i++){
int l=d[i]-s[i],r=d[i]+s[i];
st[i]=find(l);en[i]=find(r);
if (d[en[i]+1]==r) en[i]+=1;
next[i]=point[en[i]];point[en[i]]=i;
}
solve();
}