[NOIp2017]逛公园

Description

Luogu3953
求与最短路相差不超过k的路径条数。

Solution

不难想到DP,一个点的路径条数是它能到的点的路径条数之和。
考虑一个点\(i\),设\(path(S,i) + dis(i, T) = dis(S,T) + k\),其中\(path\)是当前走过的路径,\(dis\)是最短路径,考虑一条边\(e(i,j,w)\),此时\(path(S,j)+dis(j,T) = path(S,i) + w + dis(j, T) = dis(S, T) + k - dis(i, T) + w + dis(j, T)\),所以超出最短路的长度变化了\(-dis(i,T)+w+dis(j,T)\),这样就可以转移了。具体看代码。

Code

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>

const int N = 100010;
const int M = 400010;

int n, m, k, P;
int hd[N], to[M], nxt[M], w[M], fl[M], cnt;
int dis[N], vis[N];
int insta[N][55], f[N][55];
struct node {
    int p, d;
    node (int a=0, int b=0) : p(a), d(b) {}
    bool operator< (const node& x) const { return d > x.d; }
};

void adde(int x, int y, int z, int f) {
    cnt++;
    to[cnt] = y; nxt[cnt] = hd[x]; w[cnt] = z;
    fl[cnt] = f;
    hd[x] = cnt;
}

void dij() {
    std::priority_queue<node> q;
    q.push(node(n, dis[n]=0));
    while (!q.empty()) {
        node x = q.top(); q.pop();
        if (vis[x.p]) continue;
        vis[x.p] = 1;
        for (int i = hd[x.p]; i; i = nxt[i]) if (!vis[to[i]] && fl[i] == 1) {
            if (dis[to[i]] > dis[x.p] + w[i]) {
                q.push(node(to[i], dis[to[i]] = dis[x.p] + w[i]));
            }
        }
    }
}

int dfs(int p, int x) {
    if (x < 0) return 0;
    if (insta[p][x]) return -1;
    if (f[p][x]) return f[p][x];
    int ans = 0;
    if (p == n) ans = 1;
    insta[p][x] = 1;
    for (int i = hd[p]; i; i = nxt[i]) if (fl[i] == 0) {
        int tmp = dfs(to[i], x - w[i] - dis[to[i]] + dis[p]);
        if (tmp == -1) return -1;
        ans = (ans + tmp) % P;
    }
    f[p][x] = ans % P;
    insta[p][x] = 0;
    return ans % P;
}

void init() {
    memset(vis, 0, sizeof vis);
    memset(insta, 0, sizeof insta);
    memset(f, 0, sizeof f);
    memset(dis, 0x3f, sizeof dis);
    memset(hd, 0, sizeof hd);
    cnt = 0;
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        init();
        scanf("%d%d%d%d", &n, &m, &k, &P);
        for (int i = 1, x, y, z; i <= m; ++i) {
            scanf("%d%d%d", &x, &y, &z);
            adde(x, y, z, 0);
            adde(y, x, z, 1);
        }
        dij();
        printf("%d\n", dfs(1, k) % P);
    }
    return 0;
} 
posted @ 2018-09-17 20:01  wyxwyx  阅读(87)  评论(0编辑  收藏  举报