BZOJ 1040. [ZJOI2008]骑士

 

基环森林上DP。

我刚开始想的就是找到环,然后把环上每个点及它的子树缩成一个点,就变成一个环上的DP了。然后就是强制第一个不取和强制最后一个不取。

看了别人的题解发现可以不用那么麻烦,只要找到环上任意相邻的两点,强制把这条边断开,然后还是DP两次就行了。

DP方程就比较naive了。

#include <bits/stdc++.h>

namespace IO {
    char buf[1 << 21], buf2[1 << 21], a[20], *p1 = buf, *p2 = buf, hh = '\n';
    int p, p3 = -1;
    void read() {}
    void print() {}
    inline int getc() {
        return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++;
    }
    inline void flush() {
        fwrite(buf2, 1, p3 + 1, stdout), p3 = -1;
    }
    template <typename T, typename... T2>
    inline void read(T &x, T2 &... oth) {
        T f = 1; x = 0;
        char ch = getc();
        while (!isdigit(ch)) { if (ch == '-') f = -1; ch = getc(); }
        while (isdigit(ch)) { x = x * 10 + ch - 48; ch = getc(); }
        x *= f;
        read(oth...);
    }
    template <typename T, typename... T2>
    inline void print(T x, T2... oth) {
        if (p3 > 1 << 20) flush();
        if (x < 0) buf2[++p3] = 45, x = -x;
        do {
            a[++p] = x % 10 + 48;
        } while (x /= 10);
        do {
            buf2[++p3] = a[p];
        } while (--p);
        buf2[++p3] = hh;
        print(oth...);
    }
}

#define ll long long
const int N = 1e6 + 7;
int n, fa[N], root, _root;
ll dp[N][2], val[N];
bool vis[N];
int head[N];
struct E {
    int v, ne;
} e[N << 1];
int cnt = 1, ban;

void add(int u, int v) {
    e[++cnt].v = v; e[cnt].ne = head[u]; head[u] = cnt;
}

void dfs(int u, int f) {
    vis[u] = 1;
    for (int i = head[u]; i; i = e[i].ne) {
        int v = e[i].v;
        if (v == f) continue;
        if (vis[v]) {
            root = u, _root = v;
            ban = i;
        } else {
            dfs(v, u);
        }
    }
}

void DP(int u, int f) {
    dp[u][0] = 0; dp[u][1] = val[u];
    for (int i = head[u]; i; i = e[i].ne) {
        int v = e[i].v;
        if (v == f || i == ban || i == (ban ^ 1)) continue;
        DP(v, u);
        dp[u][0] += std::max(dp[v][0], dp[v][1]);
        dp[u][1] += dp[v][0];
    }
}

int main() {
    IO::read(n);
    for (int i = 1, u; i <= n; i++)
        IO::read(val[i], u), add(u, i), add(i, u);
    ll ans = 0;
    for (int i = 1; i <= n; i++) if (!vis[i]) {
        dfs(i, -1);
        DP(root, -1);
        ll temp = dp[root][0];
        DP(_root, -1);
        temp = std::max(temp, dp[_root][0]);
        ans += temp;
    }
    printf("%lld\n", ans);
    return 0;
}
View Code

 

posted @ 2020-01-25 18:00  Mrzdtz220  阅读(103)  评论(0编辑  收藏  举报