SPOJ6717 Two Paths 树形dp


首先有朴素的\(O(n^2)\)想法

首先枚举断边,之后对于断边之后的两棵子树求出直径

考虑优化这个朴素的想法

考虑换根\(dp\)

具体而言,首先求出\(f[i], fs[i]\)表示\(i\)号点向下的最长链以及\(i\)号子树内部最长的直径

并且在求出\(g[i]\)表示\(fa[i]\)\(i\)号节点子树外的最长链

\(gs[i]\)表示\(i\)号节点子树外的直径

对于所有的\(fs[i] * gs[i]\)\(max\)即为答案


首先一遍\(dfs\)\(f, fs\)求出来

考虑怎么求\(gs[o]\),有以下几种可能

  • 一条\(fa[o]\)引申出去的链 + \(fa[o]\) 除了\(o\)子树以外的最长链

  • \(gs[fa]\)

  • 两条\(fa[o]\)除了\(o\)子树以外的链的和

用前缀和后缀来做到删除


#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;

#define ll long long
#define ri register int
#define rep(io, st, ed) for(ri io = st; io <= ed; io ++)
#define drep(io, ed, st) for(ri io = ed; io >= st; io --)

#define tpr template <typename ra>
tpr inline void cmin(ra &a, ra b) { if(a > b) a = b; }
tpr inline void cmax(ra &a, ra b) { if(a < b) a = b; }
    
#define gc getchar
inline int read() {
    int p = 0, w = 1; char c = gc();
    while(c > '9' || c < '0') { if(c == '-') w = -1; c = gc(); }
    while(c >= '0' && c <= '9') p = p * 10 + c - '0', c = gc();
    return p * w;
}
    
const int sid = 200050;
    
ll ans;
int n, cnp;
int cap[sid], nxt[sid], node[sid];
    
inline void addedge(int u, int v) {
    nxt[++ cnp] = cap[u]; cap[u] = cnp; node[cnp] = v;
}
    
int f[sid], fs[sid];
int g[sid], gs[sid], q[sid], ps[sid], ss[sid];
    
#define cur node[i]
void dfs1(int o, int fa) {
    for(int i = cap[o]; i; i = nxt[i])
    if(cur != fa) {
        dfs1(cur, o);
        cmax(fs[o], fs[cur]);
        cmax(fs[o], f[cur] + f[o] + 1);
        cmax(f[o], f[cur] + 1);
    }
}

void dfs2(int o, int fa) {
    
    if(!fa) g[o] = -1;
    
    int tot = 0;
    int sx = 0, mx = 0, px = 0;
    
    for(int i = cap[o]; i; i = nxt[i])
        if(cur != fa) q[++ tot] = cur;
    
    rep(i, 1, tot + 1) ps[i] = ss[i] = 0;
    
    rep(i, 1, tot) {
        int p = q[i];
        cmax(g[p], g[o] + 1);
        cmax(g[p], ps[i - 1]);
        ps[i] = max(ps[i - 1], f[p] + 1);
    }
    
    drep(i, tot, 1) {
        int p = q[i];
        cmax(g[p], ss[i +  1]);
        ss[i] = max(ss[i + 1], f[p] + 1);
    }
    
    rep(i, 1, tot) {
        int p = q[i];
        cmax(gs[p], ps[i - 1] + ss[i + 1]);
        cmax(gs[p], gs[o]);
        cmax(gs[p], g[o] + 1 + ps[i - 1]);
        cmax(gs[p], g[o] + 1 + ss[i + 1]);
    }
    
    rep(i, 1, tot) {
        int p = q[i];
        cmax(gs[p], mx + sx);
        cmax(gs[p], px); cmax(px, fs[p]);
        if(f[p] + 1 >= mx) sx = mx, mx = f[p] + 1;
        else if(f[p] + 1 >= sx) sx = f[p] + 1;
    }
    
    mx = sx = px = 0;
    
    drep(i, tot, 1) {
        int p = q[i];
        cmax(gs[p], mx + sx);
        cmax(gs[p], px); cmax(px, fs[p]);
        if(f[p] + 1 >= mx) sx = mx, mx = f[p] + 1;
        else if(f[p] + 1 >= sx) sx = f[p] + 1;
    }
    
    for(int i = cap[o]; i; i = nxt[i])
        if(cur != fa) {
            cmax(ans, 1ll * fs[cur] * gs[cur]);
            dfs2(cur, o);
        }
}

int main() {
    n = read();
    rep(i, 2, n) {
        int u = read(), v =read();
        addedge(u, v); addedge(v, u);
    }
    dfs1(1, 0); dfs2(1, 0);
    printf("%lld\n", ans);
    return 0;
}
posted @ 2018-11-27 21:52  remoon  阅读(198)  评论(0编辑  收藏  举报