CF传送门
洛谷传送门
BZOJ传送门
题解:
首先我们发现只需要考虑所有的 ( u , v ) (u,v) (u,v)和 ( v , u ) (v,u) (v,u)的飞行计划就行了。
而在考虑 ( u , v ) (u,v) (u,v)的时候,其他的并不重要。
我们也只需要考虑所有 ( u , v ) (u,v) (u,v)和 ( v , u ) (v,u) (v,u)的相对顺序。
现在问题变为:
有一个 01 01 01序列(即 ( u , v ) (u,v) (u,v)和 ( v , u ) (v,u) (v,u)),有如下四种操作:
- 花费 v 1 v_1 v1代价消除一个 0 0 0,即 ( u , v ) (u,v) (u,v)单程
- 花费 v 2 v_2 v2代价消除一个 0 0 0,并消除一个它后面的 1 1 1,即 ( u , v ) (u,v) (u,v)往返
- 花费 v 3 v_3 v3代价消除一个 1 1 1,即 ( v , u ) (v,u) (v,u)单程
- 花费 v 4 v_4 v4代价消除一个 1 1 1,并消除一个它后面的 0 0 0,即 ( v , u ) (v,u) (v,u)往返。
问消空这个序列的最小代价。
显然当单程花费大于往返的时候,我们会考虑用往返票代替单程。
并且当往返大于两个单程之和的时候,我们会考虑用两个单程代替往返。
显然可以先进行如下的赋值操作:
v1=min(v1,v2);v3=min(v3,v4);
v2=min(v2,v1+v3);v4=min(v4,v1+v3);
假定 v 2 ≤ v 4 v_2\leq v_4 v2≤v4,此时必然有 v 2 ≤ v 4 ≤ v 1 + v 3 v_2\leq v_4\leq v_1+v_3 v2≤v4≤v1+v3
那么贪心先消 01 01 01,然后消 10 10 10 ,剩下的单消就行了。
代码:
#include<bits/stdc++.h>
#define ll long long
#define re register
#define gc get_char
#define cs const
namespace IO{
static cs int Rlen=1<<22|1;
static char buf[Rlen],*p1,*p2;
inline char get_char(){return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++;}
inline char peek(){return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1;}
inline char get_sth(){while(isspace(peek()))gc();return gc();}
template<typename T>
inline T get(){
char c;
while(!isdigit(c=gc()));T num=c^48;
while(isdigit(c=gc()))num=(num+(num<<2)<<1)+(c^48);
return num;
}
inline int getint(){return get<int>();}
}
using namespace IO;
using std::cerr;
using std::cout;
cs int N=5e5+5;
int n,d,m;
std::unordered_map<ll,int> ma;int id;
inline int get_id(int u,int v){
return ma.count((ll)u*n+v)?ma[(ll)u*n+v]:(ma[(ll)u*n+v]=++id);
}
int nd[N];
bool sol[N];
std::vector<int> pos[N];
int w1[N],w2[N];
ll ans;
int ban1[N],ban2[N];
signed main(){
// freopen("jazz.in","r",stdin);//freopen("jazz.out","w",stdout);
n=getint(),d=getint();nd[1]=getint();
for(int re i=2;i<=d;++i){
nd[i]=getint();
pos[get_id(nd[i-1],nd[i])].push_back(i);
}
memset(w1,0x7f,sizeof w1);
memset(w2,0x7f,sizeof w2);
m=getint();
while(m--){
int u=getint(),v=getint();
char c=get_sth();int w=getint();
if(!ma.count((ll)u*n+v))continue;
int t=ma[(ll)u*n+v],&tmp=c=='O'?w1[t]:w2[t];
tmp=std::min(tmp,w);
}
for(int re tt=2;tt<=d;++tt){
int i1=ma[(ll)nd[tt-1]*n+nd[tt]];
if(sol[i1])continue;
sol[i1]=true;
std::vector<int> t1=pos[i1];
int v1=w1[i1],v2=w2[i1];
int i2=ma.count((ll)nd[tt]*n+nd[tt-1])?ma[(ll)nd[tt]*n+nd[tt-1]]:0;
if(!i2){
ans+=std::min(v1,v2)*(ll)t1.size();
continue;
}
sol[i2]=true;
std::vector<int> t2=pos[i2];
int v3=w1[i2],v4=w2[i2];
v1=std::min(v1,v2);v3=std::min(v3,v4);
v2=std::min(v2,v1+v3);v4=std::min(v4,v1+v3);
if(v2>v4){
std::swap(t1,t2);
std::swap(v2,v4);
std::swap(v1,v3);
}
ll tmp=0;
int re i=0,j=0,c1=t1.size(),c2=t2.size();
while(i<t1.size()&&j<t2.size()){
if(t1[i]<t2[j]){
ban1[i++]=ban2[j++]=tt;
--c1,--c2;
tmp+=v2;
}
else ++j;
}
i=0,j=0;
while(i<t1.size()&&j<t2.size()){
if(ban1[i]==tt){++i;continue;}
if(ban2[j]==tt){++j;continue;}
if(t2[j]<t1[i]){
++i,++j;
--c1,--c2;
tmp+=v4;
}
else ++i;
}
ans+=tmp+(ll)c1*v1+(ll)c2*v3;
}
cout<<ans<<"\n";
return 0;
}