HDU5293 Tree chain problem (LCA+树链剖分+线段树)

要选出价值最大的链且不相交。

比较朴素的想法就是树dp,对于一个子树,如果这个点在一条链上,我们可以考虑是否选这个链,然后把这条链上的点单独考虑,对于剩下的子树直接求和即可

但是由于两个点不一定在同一条到根的路径上,因此我们对于每条链,都把他存到lca的位置上再考虑

画一下图,我们可以发现只要在当前点维护一个子树和-当前节点的答案,这样答案就可以通过贡献来做。

维护这个信息的原因是,我们需要子节点的答案,但是却要刨除在链上的那些点的答案

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pll;
const int N=2e5+10;
const int mod=998244353;
int h[N],ne[N],e[N],idx,f[N][25];
int ff[N],depth[N];
int dfn[N],times,top[N],sz[N],son[N],id[N];
void add(int a,int b){
    e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
int n,m;
struct node{
    int l,r;
    ll sum;
}tr[N<<2];
struct dd{
    int a,b,c;
};
vector<dd> num[N];
void build(int u,int l,int r){
    if(l==r){
        tr[u]={l,r};
    }
    else{
        tr[u]={l,r};
        int mid=l+r>>1;
        build(u<<1,l,mid);
        build(u<<1|1,mid+1,r);
    }
}
void dfs(int u,int fa){
    depth[u]=depth[fa]+1;
    f[u][0]=fa;
    int i;
    sz[u]=1;
    for(i=1;i<=22;i++){
        f[u][i]=f[f[u][i-1]][i-1];
    }
    for(i=h[u];i!=-1;i=ne[i]){
        int j=e[i];
        if(j==fa)
            continue;
        ff[j]=u;
        dfs(j,u);
        sz[u]+=sz[j];
        if(sz[j]>sz[son[u]]){
            son[u]=j;
        }
    }
}
void dfs1(int u,int x){
    dfn[u]=++times;
    id[times]=u;
    top[u]=x;
    if(!son[u])
        return;
    dfs1(son[u],x);
    int i;
    for(i=h[u];i!=-1;i=ne[i]){
        int j=e[i];
        if(j==ff[u]||j==son[u])
            continue;
        dfs1(j,j);
    }
}
int lca(int a,int b){
    if(depth[a]<depth[b])
        swap(a,b);
    int i;
    for(i=21;i>=0;i--){
        if(depth[f[a][i]]>=depth[b]){
            a=f[a][i];
        }
    }
    if(a==b)
        return a;
    for(i=21;i>=0;i--){
        if(f[a][i]!=f[b][i]){
            a=f[a][i];
            b=f[b][i];
        }
    }
    return f[a][0];
}
ll sum[N],dp[N];
void pushup(int u){
    tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
}
void modify(int u,int l,int x){
    if(tr[u].l==tr[u].r){
        tr[u].sum+=x;
        return ;
    }
    int mid=tr[u].l+tr[u].r>>1;
    if(l<=mid){
        modify(u<<1,l,x);
    }
    else{
        modify(u<<1|1,l,x);
    }
    pushup(u);
}
ll query(int u,int l,int r){
    if(tr[u].l>=l&&tr[u].r<=r){
        return tr[u].sum;
    }
    ll ans=0;
    int mid=tr[u].l+tr[u].r>>1;
    if(l<=mid)
        ans+=query(u<<1,l,r);
    if(r>mid)
        ans+=query(u<<1|1,l,r);
    return ans;
}
ll querypath(int x,int y){
    ll ans=0;
    while(top[x]!=top[y]){
        if(depth[top[x]]<depth[top[y]])
            swap(x,y);
        ans+=query(1,dfn[top[x]],dfn[x]);
        x=ff[top[x]];
    }
    if(depth[x]>depth[y])
        swap(x,y);
    ans+=query(1,dfn[x],dfn[y]);
    return ans;
}
void solve(int u,int fa){
    int i;
    ll tmp=0;
    for(i=h[u];i!=-1;i=ne[i]){
        int j=e[i];
        if(j==fa)
            continue;
        solve(j,u);
        tmp+=dp[j];
    }
    dp[u]=tmp;
    for(i=0;i<num[u].size();i++){
        auto t=num[u][i];
        dp[u]=max(dp[u],tmp+querypath(t.a,u)+querypath(t.b,u)+t.c);
    }
    modify(1,dfn[u],tmp-dp[u]);
}
int main(){
    ios::sync_with_stdio(false);
    int t;
    cin>>t;
    while(t--){
        int i;
        cin>>n>>m;
        idx=0;
        times=idx=0;
        for(i=0;i<=n;i++){
            h[i]=-1;
            num[i].clear();
            son[i]=0;
            dp[i]=0;
        }
        for(i=1;i<n;i++){
            int a,b;
            cin>>a>>b;
            add(a,b);
            add(b,a);
        }
        dfs(1,0);
        dfs1(1,1);
        build(1,1,n);
        for(i=1;i<=m;i++){
            int a,b,w;
            cin>>a>>b>>w;
            int p=lca(a,b);
            num[p].push_back({a,b,w});
        }
        solve(1,0);
        cout<<dp[1]<<endl;
    }
    return 0;
}
View Code

 

posted @ 2021-04-30 16:13  朝暮不思  阅读(44)  评论(0编辑  收藏  举报