luogu 5311 [Ynoi2011]D1T3 动态点分治+树状数组

我这份代码已经奇怪到一定程度了~

洛谷上一直 $TLE$,但是本地造了几个数据都过了.

简单说一下题解:

先建出来点分树.
对于每一个询问,在点分树中尽可能向上跳祖先,看是否能够处理这个询问.
找到最高点的好处就是该点的询问可以全部由那个祖先来统计.
因为祖先到 $x$ 是合法的,而那个祖先会统计子树里所有的点,当然也包括 $x$ 的所有子树.
假设现在枚举到点 $x$ ,并处理 $x$ 上面的所有询问.
$x$ 子树中的每一个点都可以用一个三元组来表示:$(l,r,c)$ 代表该点到 $x$ 路径的值域在 $[l,r],$颜色为 $c.$
那么对于一个询问,就是查询所有在 $[l_{q},r_{q}]$ 中 $c$ 的不同种类.
我们先将 $l$ 从大到小排序,依次处理点和询问.
如果有多个点,那么显然 $r$ 小的优先级会更高,即后加入的同颜色的点如果 $r$ 更小就直接替换.
这么做就能保证所有颜色的点在当前局面只出现一次.
现在的问题就是统计 $(l,r,l,r)$ 这个矩形内点的数量.
对于这个问题,可以用 $O(logn)$ 的树状数组来进行数点,总时间复杂度为 $O(nlog^2n)$. 

Code: 

 

#include <bits/stdc++.h>  
#define N 1000005    
#define inf 100001
#define setIO(s) freopen(s".in","r",stdin) ,freopen(s".out","w",stdout)     
using namespace std;       
int edges,n,flag;
int hd[N],to[N<<1],nex[N<<1],val[N],rt[N<<2]; 
void add(int u,int v)
{
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
} 
namespace BIT 
{  
    int C[N]; 
    int lowbit(int t) 
    {
        return t&(-t);  
    } 
    void update(int x,int v) 
    {
        for(;x<N;x+=lowbit(x)) C[x]+=v; 
    } 
    int query(int x) 
    {
        int re=0; 
        for(;x>0;x-=lowbit(x)) re+=C[x]; 
        return re;    
    }
    void re(int x) 
    {
        for(;x<N;x+=lowbit(x)) C[x]=0;      
    }
}; 
struct Node
{
    int l,r,x;    
    Node(int l=0,int r=0,int x=0):l(l),r(r),x(x){}    
};
vector<Node>e[N];    
struct Point
{
    int l,r,id,val;
    Point(int l=0,int r=0,int id=0,int val=0):l(l),r(r),id(id),val(val){}  
};      
vector<Point>F[N];   
bool cmp(Point a,Point b)
{
    return a.l==b.l?a.id>b.id:a.l>b.l;                  
}
int root,sn;
int mx[N],size[N],vis[N],Fa[N],answer[N],lsty[N],lstx[N];      
void dfs(int u,int ff)
{
    size[u]=1;
    for(int i=hd[u];i;i=nex[i])
        if(to[i]!=ff&&!vis[to[i]])
            dfs(to[i],u),size[u]+=size[to[i]];    
}
void getroot(int u,int ff)
{
    size[u]=1,mx[u]=0; 
    for(int i=hd[u];i;i=nex[i])
        if(to[i]!=ff&&!vis[to[i]])
            getroot(to[i],u),size[u]+=size[to[i]],mx[u]=max(mx[u],size[to[i]]);
    mx[u]=max(mx[u],sn-size[u]);   
    if(mx[u]<mx[root]) root=u;
}
void calc(int u,int ff,int Min,int Max,int rt)
{
    Min=min(Min,u),Max=max(Max,u);
    F[rt].push_back(Point(Min,Max,u,val[u])), e[u].push_back(Node(Min,Max,rt));    
    for(int i=hd[u];i;i=nex[i])
        if(!vis[to[i]]&&to[i]!=ff)
            calc(to[i],u,Min,Max,rt);
}
void prepare(int u)
{
    vis[u]=1;
    calc(u,0,u,u,u);
    for(int i=hd[u];i;i=nex[i])
        if(!vis[to[i]])
            dfs(to[i],u),sn=size[to[i]],root=0,getroot(to[i],u),Fa[root]=u,prepare(root);        
}
void Push(int u,int l,int r,int id)
{
    for(int i=0;i<(int)e[u].size();++i)
    {
        if(e[u][i].l>=l&&e[u][i].r<=r) 
        {
            F[e[u][i].x].push_back(Point(l,r,-1,id));
            break;
        }
    }
}         
void solve(int u)
{  
    int i,j;
    sort(F[u].begin(),F[u].end(),cmp);  
    for(i=0;i<(int)F[u].size();++i)
    {
        Point p=F[u][i];     
        if(p.id==-1)
            answer[p.val]=BIT::query(p.r);        
        else
        {  
            if(!lsty[p.val]||p.r<=lsty[p.val])
            {            
                if(lsty[p.val]) 
                {
                    BIT::update(lsty[p.val],-1);                          
                } 
                BIT::update(p.r,1);      
                lstx[p.val]=p.l, lsty[p.val]=p.r;                 
            }
        }
    }  
    for(i=0;i<(int)F[u].size();++i)
        if(F[u][i].id!=-1)
        {   
            if(lsty[F[u][i].val]) BIT::re(lsty[F[u][i].val]);    
            lstx[F[u][i].val]=lsty[F[u][i].val]=0;    
        }
}
int main()
{
    int i,j,m;
    // setIO("input");
    scanf("%d%d",&n,&m); 
    for(i=1;i<=n;++i) scanf("%d",&val[i]); 
    for(i=1;i<n;++i)
    {
        int x,y;
        scanf("%d%d",&x,&y),add(x,y),add(y,x);
    }
    root=0,mx[0]=sn=n,getroot(1,0),prepare(root);   
    for(i=1;i<=m;++i)
    {                 
        int l,r,x;
        scanf("%d%d%d",&l,&r,&x),Push(x,l,r,i);                  
    }            
    for(i=1;i<=n;++i) if(F[i].size()) solve(i);      
    for(i=1;i<=m;++i) printf("%d\n",answer[i]);                                  
    return 0;
}  

 

  

 

  

posted @ 2019-09-05 16:57  EM-LGH  阅读(231)  评论(0编辑  收藏  举报