换根DP

换根DP

换根法套路:枚举每个点为根做一遍dp

简化为二次扫描换根法

1,随便找一个点作为根进行dp,

2,再以原来点为根进行dp,此次dp,设最优解为 f[x],那么f[root]=d[root],这是显而易见的

然后再通过找d[son]与f[x]之间关系进行dp

例1:POJ3585

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;

const int N=2e5+5;
int T,f[N],n,d[N],ans,du[N];
struct node{
    int to,nxt,w;
}e[N<<1];
int hd[N<<1],tot;
void add(int x,int y,int z) {
    e[++tot].to=y;e[tot].w=z;e[tot].nxt=hd[x];hd[x]=tot;
}
void init() {
    tot=0;
    memset(e,0,sizeof(0));
    memset(hd,0,sizeof(hd));
    memset(f,0,sizeof(f));
    memset(d,0,sizeof(d));
    memset(du,0,sizeof(du));
}
void dp(int x,int fa) {
    d[x]=0;
    for(int i=hd[x];i;i=e[i].nxt) {
        int y=e[i].to;
        if(y==fa) continue;
        dp(y,x);
        if(du[y]==1) d[x]+=e[i].w;
        else d[x]+=min(d[y],e[i].w);
    }
}
void dfs(int x,int fa) {
    for(int i=hd[x];i;i=e[i].nxt) {
        int y=e[i].to;
        if(y==fa) continue;
        if(du[x]==1) f[y]=d[y]+e[i].w;
        else f[y]=d[y]+min(f[x]-min(d[y],e[i].w),e[i].w);
        dfs(y,x);
    }    
}
int main(){
    scanf("%d",&T);
    while(T--) {
        init();
        scanf("%d",&n);
        for(int i=1,x,y,z;i<n;i++) {
            scanf("%d%d%d",&x,&y,&z);
            add(x,y,z);add(y,x,z);
            du[y]++;du[x]++;
        }
        dp(1,0);
        f[1]=d[1];
        dfs(1,0);
        ans=0;
        for(int i=1;i<=n;i++)
            ans=max(ans,f[i]);
        printf("%d\n",ans);
    }
    return 0;
}

例2:STA-Station

#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return f*x;
}
const int N=2e6;
int n;
vector<int>g[N];
long long siz[N],fa[N],dep[N];
void dfs(int x,int f){
    siz[x]=1;fa[x]=f;dep[x]=dep[f]+1;
    for(int i=0;i<g[x].size();i++){
        int y=g[x][i];
        if(y==f) continue;
        dfs(y,x);
        siz[x]+=siz[y];
    }
}
long long dp[N],ans=1;
void get_ans(int x){
    for(int i=0;i<g[x].size();i++){
        int y=g[x][i];
        if(y==fa[x]) continue;
        dp[y]=dp[x]-siz[y]+n-siz[y];
        if(dp[y]>dp[ans]) ans=y;
        get_ans(y);
    }
}
int main(){
    n=read();
    for(int i=1,x,y;i<n;i++){
        x=read();y=read();
        g[x].push_back(y);
        g[y].push_back(x);
    }
    dfs(1,0);
    for(int i=1;i<=n;i++)
        dp[1]+=dep[i];
    get_ans(1);
    printf("%lld\n",ans);
    return 0;
}
posted @ 2020-08-15 15:31  ke_xin  阅读(95)  评论(0编辑  收藏  举报