题意:给出一张图,求出从s点到t点的第K短路。如果没有,输出-1.
思路:因为K<=1000,所以我们可以枚举距离从小到大所有的从s到t的路,找到第K短即可。
这里我们用A*算法。对应的h函数是每个点到t点的最短距离,可以看出这是最终距离的下界。
需要注意一点的是,在A*算法中,我们需要重复的加点,但是每个点至多加K次。如果加了K+1次,那对应的路的距离至少是第K+1短路。
代码如下:
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>
using namespace std;
int N,M,S,T,K;
const int MAXN = 1010;
const int MAXM = 100100;
struct edge{
int to,cost;
edge(){}
edge(int t, int c):to(t),cost(c){}
}edges1[MAXM],edges2[MAXM];
int tot1,tot2;
int head1[MAXN],head2[MAXN];
int nxt1[MAXM],nxt2[MAXM];
int dis[MAXN];
bool used[MAXN];
int cnt[MAXN];
struct node1{
int u, cost;
node1(){}
node1(int c, int uu):cost(c),u(uu){}
bool operator < (const node1 & rhs) const{
return cost > rhs.cost;
}
};
struct node2{
int u,d;
node2(){};
node2(int dd, int uu):u(uu),d(dd){};
bool operator < (const node2 & rhs) const{
return d + dis[u] > rhs.d + dis[rhs.u];
}
};
int Astar()
{
memset(cnt,0,sizeof(cnt));
priority_queue<node2> Q;
Q.push(node2(0,S));
while(!Q.empty()){
node2 x = Q.top();Q.pop();
cnt[x.u]++;
if(cnt[x.u] == K && x.u == T) return x.d;
else if(cnt[x.u] <= K){
for(int i = head1[x.u]; ~i; i = nxt1[i]){
edge & e = edges1[i];
Q.push(node2(x.d + e.cost,e.to));
}
}
}
return -1;
}
void dijstra()
{
memset(dis,0x3f,sizeof(dis));
memset(used,0,sizeof(used));
dis[T] = 0;
priority_queue<node1> Q;
Q.push(node1(0,T));
while(!Q.empty()){
node1 x = Q.top();Q.pop();
if(used[x.u]) continue;
used[x.u] = true;
for(int i = head2[x.u]; ~i; i = nxt2[i]){
edge & e = edges2[i];
if(dis[e.to] > x.cost + e.cost){
dis[e.to] = x.cost + e.cost;
Q.push(node1(dis[e.to],e.to));
}
}
}
}
void init()
{
tot1 = tot2 = 0;
memset(head1,-1,sizeof(head1));
memset(head2,-1,sizeof(head2));
}
void addedge(int u, int v, int c)
{
edges1[tot1] = edge(v,c);
nxt1[tot1] = head1[u];
head1[u] = tot1++;
edges2[tot2] = edge(u,c);
nxt2[tot2] = head2[v];
head2[v] = tot2++;
}
int main(void)
{
//freopen("input.txt","r",stdin);
scanf("%d%d", &N, &M);
init();
for(int i = 0; i < M; ++i){
int a,b,t;
scanf("%d%d%d",&a,&b,&t);
addedge(a,b,t);
}
scanf("%d%d%d",&S,&T,&K);
if(S == T) K++;
dijstra();
printf("%d\n",Astar());
return 0;
}