题解【POJ3463】Sightseeing

题面

题意简述:

求最短路和比最短路大 \(1\) 的路径条数。

考虑在 Dijkstra 算法时同时求出最短路和次短路,以及它们的条数。

于是我们改变一下堆中存储的数据,多存储一下这个点的类型(最短路或次短路)。

然后在枚举点的时候分类讨论一下就好了。

#include <iostream>
#include <cstdio>
#include <queue>
#include <cstring>
#include <algorithm>
#include <cmath>

using namespace std;

const int N = 1003, M = 10003;

int n, m, t;
int tot, head[N], ver[M], nxt[M], edge[M];
int S, T;
int dist[N][2], cnt[N][2];
bool vis[N][2];

struct Node
{
    int ver, ty, dis; //点的编号、类型(1 为次短路,0 为最短路) 和 从 1 到当前点的距离
    bool operator > (const Node &a) const
    {
        return dis > a.dis; //重载运算符
    }
} ;

inline void add(int u, int v, int w)
{
    ver[++tot] = v, edge[tot] = w, nxt[tot] = head[u], head[u] = tot;
}

inline int Dij()
{
    memset(cnt, 0, sizeof cnt);
    memset(dist, 0x3f, sizeof dist);
    memset(vis, false, sizeof vis);
    dist[S][0] = 0, cnt[S][0] = 1; //初始化时只有最短路
    priority_queue <Node, vector <Node>, greater <Node> > q;
    q.push((Node){S, 0, 0});
    while (!q.empty())
    {
        Node t = q.top(); q.pop();
        int u = t.ver, ty = t.ty, dis = t.dis, cntu = cnt[u][ty];
        if (vis[u][ty]) continue;
        vis[u][ty] = true;
        for (int i = head[u]; i; i = nxt[i])
        {
            int v = ver[i], w = edge[i];
            if (dist[v][0] == dis + w) cnt[v][0] += cntu; //与最短路长度相同
            else if (dist[v][0] > dis + w) //比最短路还短
            {
                dist[v][1] = dist[v][0], cnt[v][1] = cnt[v][0]; //先更新次短路为当前的最短路
                q.push((Node){v, 1, dist[v][1]}); //放入堆中
                dist[v][0] = dis + w, cnt[v][0] = cntu; //更新最短路
                q.push((Node){v, 0, dist[v][0]}); //将最短路放入堆中
            }
            else if (dist[v][1] == dis + w) cnt[v][1] += cntu; //与次短路长度相同
            else if (dist[v][1] > dis + w) //比次短路短
            {
                dist[v][1] = dis + w, cnt[v][1] = cntu;
                q.push((Node){v, 1, dist[v][1]});
            }
        }
    }
    int ans = cnt[T][0];
    if (dist[T][0] + 1 == dist[T][1]) //存在比最短路长度多 1 的次短路
        ans += cnt[T][1];
    return ans;
}

int main()
{
    cin >> t;
    while (t--)
    {
        memset(head, 0, sizeof head);
        tot = 0;
        cin >> n >> m;
        for (int i = 1; i <= m; i+=1)
        {
            int u, v, w;
            cin >> u >> v >> w;
            add(u, v, w); //注意是单向边
        }
        cin >> S >> T;
        cout << Dij() << endl;
    }
    return 0;
}
posted @ 2020-03-01 15:25  csxsi  阅读(136)  评论(0编辑  收藏  举报