BZOJ 4860: [BeiJing2017]树的难题 点分治+线段树

这个思路还是非常巧妙的. 

困难在于我们需要同时维护以 $x$ 为分治中心,延伸出颜色相同/不同的最大值.      

不同的话直接将权和相加,相同的话还需要减掉重复部分,这就比较难办.  

但是我们发现,当以 $x$ 为分治中心时,$x$ 每一个儿子为根的子树的延伸颜色都是相同的. 

所以我们可以将每一个点的所有儿子按照延伸颜色排序,然后维护两颗线段树:相同与不同. 

当我们在点分治时处理到下一个儿子,而下一个儿子与当前儿子颜色不同时,用线段树合并的方式将相同线段树合并到不同即可. 

code:

#include <cstdio> 
#include <string>  
#include <vector>
#include <cstring> 
#include <algorithm>   
#define N 200006 
#define ll long long  
#define inf 1000000009 
using namespace std; 
void setIO(string s) 
{
    freopen((s+".in").c_str(),"r",stdin);  
    // freopen((s+".out").c_str(),"w",stdout);  
}
int answer=-inf,mp=0; 
namespace seg
{   
    int tot;     
    struct node { int ls,rs,maxx; }s[N*50];      
    void clr()  
    { 
        for(int i=1;i<=tot;++i) s[i].ls=s[i].rs=0,s[i].maxx=-inf;   
        tot=0;  
    }     
    int merge(int x,int y) 
    {                               
        if(!x||!y) return x+y;     
        s[x].maxx=max(s[x].maxx,s[y].maxx);    
        s[x].ls=merge(s[x].ls,s[y].ls); 
        s[x].rs=merge(s[x].rs,s[y].rs);   
        return x;  
    }
    void update(int &x,int l,int r,int p,int v) 
    {
        if(!x) x=++tot,s[x].maxx=-inf;    
        s[x].maxx=max(s[x].maxx,v); 
        if(l==r) return;  
        int mid=(l+r)>>1;   
        if(p<=mid)  update(s[x].ls,l,mid,p,v); 
        else update(s[x].rs,mid+1,r,p,v);   
    }       
    int query(int x,int l,int r,int L,int R) 
    {
        if(!x||L>R) return -inf; 
        if(l>=L&&r<=R) return s[x].maxx;  
        int mid=(l+r)>>1,re=-inf;  
        if(L<=mid) re=max(re,query(s[x].ls,l,mid,L,R)); 
        if(R>mid)  re=max(re,query(s[x].rs,mid+1,r,L,R));     
        return re;   
    }
};  
int n,m,_min,_max,root,sn,tot;       
int val[N],size[N],mx[N],vis[N];     
struct Edge 
{
    int to,w;    
    Edge(int to=0,int w=0):to(to),w(w){}  
};  
struct Dis 
{ 
    int d,v; 
    Dis(int d=0,int v=0):d(d),v(v){}  
}A[N];  
vector<Edge>F[N];  
vector<Edge>G[N];  
bool cmp(Edge a,Edge b) { return a.w<b.w; }   
void getroot(int u,int ff) 
{      
    size[u]=1,mx[u]=0; 
    for(int i=0;i<G[u].size();++i) 
    {
        int v=G[u][i].to;   
        if(vis[v]||v==ff)  continue;       
        getroot(v,u),size[u]+=size[v];   
        mx[u]=max(mx[u],size[v]);                                 
    }
    mx[u]=max(mx[u],sn-size[u]); 
    if(mx[u]<mx[root]) root=u;  
} 
void dfs(int u,int ff,int d,int v,int pr) 
{     
    A[++tot]=Dis(d,v);             
    for(int i=0;i<G[u].size();++i) 
    {
        int y=G[u][i].to; 
        if(vis[y]||y==ff) continue;     
        dfs(y,u,d+1,v+val[G[u][i].w]*(G[u][i].w!=pr),G[u][i].w);                                        
    }
}
void solve(int u) 
{       
    int rt_diff=0,rt_same=0;    
    F[u].clear();  
    for(int i=0;i<G[u].size();++i) if(!vis[G[u][i].to]) F[u].push_back(G[u][i]);   
    for(int i=0;i<F[u].size();++i) 
    {
        int v=F[u][i].to;          
        tot=0; 
        dfs(v,u,1,val[F[u][i].w],F[u][i].w);                       
        for(int j=1;j<=tot;++j) 
        {   
            if(A[j].d>_max)  continue;     
            if(A[j].d>=_min&&A[j].d<=_max) 
            {    
                answer=max(answer,A[j].v),mp=max(mp,A[j].v);      
            }
            answer=max(answer,A[j].v+seg::query(rt_diff,1,n,max(1,_min-A[j].d),_max-A[j].d));       
            answer=max(answer,A[j].v+seg::query(rt_same,1,n,max(1,_min-A[j].d),_max-A[j].d)-val[F[u][i].w]);   
        }
        for(int j=1;j<=tot;++j) seg::update(rt_same,1,n,A[j].d,A[j].v);    
        if(i<F[u].size()-1&&F[u][i+1].w!=F[u][i].w)  
        {
            rt_diff=seg::merge(rt_diff,rt_same); 
            rt_same=0;  
        }                        
    }   
    seg::clr();              
    vis[u]=1;   
    for(int i=0;i<F[u].size();++i) 
    {
        int v=F[u][i].to;     
        sn=size[v],root=0,getroot(v,u),solve(root);  
    }
}
int main() 
{ 
    // setIO("input");  
    int i,j;   
    scanf("%d%d%d%d",&n,&m,&_min,&_max);   
    for(i=1;i<=m;++i)  scanf("%d",&val[i]);     
    for(i=1;i<n;++i) 
    {
        int x,y,z; 
        scanf("%d%d%d",&x,&y,&z);   
        G[x].push_back(Edge(y,z));   
        G[y].push_back(Edge(x,z));  
    }
    for(i=1;i<=n;++i) sort(G[i].begin(),G[i].end(),cmp);                  
    sn=n,mx[root=0]=N,getroot(1,0),solve(root),printf("%d\n",answer); 
    return 0; 
}

  

posted @ 2019-12-31 15:05  EM-LGH  阅读(152)  评论(0编辑  收藏  举报