树上莫队

在之前我已经发过了普通莫队的博客了。
传送门
打了几道莫队的裸题后,我就学了一下树上莫队。


例题

这题的英文超好懂,我相信你的英语水平。
但我还是解释一下吧。
题目大意:给你一棵N个点的带权的树,有M个询问,询问两点之间不同的权值个数。
其中N10000,M400000

往区间的方向思考一下

处理树上的信息,有一个传统的套路,就是将树转化成序列。
什么dfs序,欧拉序,括号序。
那么咋转化成序列?
思考一下~

树上莫队

树上莫队用的是括号序
不知道这是啥玩意?这就是dfs过程中,入栈和出栈各记一次的顺序。显然长度是N2的。

做法

对于uv之间的路径,假设uin<vin
分两种情况:
1. LCA(u,v)=uin时,可以询问区间[uin,vin]
2. 否则,询问区间[uout,vin],并且另外加上LCA(u,v)的贡献。
询问区间是询问区间中只出现过一次的节点中的答案。

Why?

可以试着随手画一棵树,手玩一下~
先说第一种情况:
uv的路径是一条链。意味着链上的每个后代都被祖先的入序和出序夹在中间。
那么这样询问,显然链上的每一个点都必定只出现一次。
而那些不在链上但在区间中的点,必然是出现了两次的。
再说第二种情况:
在这个区间内,显然出现了两遍的点都是不在这条路径上的。
这个区间中,可以理解成,从u中离开,继续对这棵树遍历,其中,由于它的祖先在前面已经出现过一次了,所以在后面只可能出现一次,而那些没有用的节点,就会出现两次。当再次回到LCA时,继续遍历,v的祖先只会在这个区间中出现一次,无关的点出现两次。
但是这样显然没有LCA,所以LCA另外算。
理解比较抽象,注意思考……
这就是dfs搞出来的括号序的应用。具体为什么,可以在dfs中理解一下。dfs的性质很神奇。

后面的事……

根据这些变成一条条区间询问,然后和正常的莫队一样就行了。


代码

这题我没有AC。
原因是我并没有SPOJ的账号,并且注册不了。
就当模板用吧。

using namespace std;
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#define MAXN 10000
#define MAXM 400000
int n,m;
int col[MAXN+1];
int *p[MAXN+1];
bool cmpp(int *x,int *y){
    return *x<*y;
}
struct EDGE{
    int to;
    EDGE *las;  
} e[MAXN*2+1];
int ne;
EDGE *last[MAXN+1];
void insert_edge(int u,int v){
    e[++ne]={v,last[u]};
    last[u]=e+ne;
}
int in[MAXN+1],out[MAXN+1],nowdfn;
int dy[MAXN*2+1];
int unit,be[MAXN*2+1];
int dep[MAXN+1];
int fa[MAXN+1][15];
void init(int);
int LCA(int,int);
struct Operation{
    int time,l,r,another;//如果要额外算LCA,another即为LCA的颜色,否则为0
} o[MAXM+1];
bool cmp(const Operation &x,const Operation &y){
    return be[x.l]<be[y.l] || be[x.l]==be[y.l] && x.r<y.r;
}
int num[MAXN+1];//表示某个颜色的出现次数
int gx[MAXN+1];//贡献,表示是否只出现一次(其实可以用bool数组)
int ans[MAXM+1];
int main(){
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;++i)
        scanf("%d",&col[i]),p[i]=col+i;
    sort(p+1,p+n+1,cmpp);
    for (int i=1,k=0,last=-2147483648;i<=n;++i){//离散化
        if (last!=*p[i])
            ++k,last=*p[i];
        *p[i]=k;
    }
    for (int i=1;i<n;++i){
        int u,v;
        scanf("%d%d",&u,&v);
        insert_edge(u,v),insert_edge(v,u);
    }
    init(1);
    unit=sqrt(nowdfn);
    for (int i=1;i<=nowdfn;++i)
        be[i]=(i-1)/unit+1;
    for (int i=1;i<=m;++i){
        o[i].time=i;
        int u,v;
        scanf("%d%d",&u,&v);
        if (in[u]>in[v])
            swap(u,v);
        if (out[v]<out[u]){
            o[i].l=in[u];
            o[i].r=in[v];
            o[i].another=0;
        }
        else{
            o[i].l=out[u];
            o[i].r=in[v];
            o[i].another=col[LCA(u,v)];
        }
    }
    sort(o+1,o+m+1,cmp);
    int nowl=1,nowr=0,nowans=0;
    for (int i=1;i<=m;++i){
        while (nowr<o[i].r){
            nowr++;
            int lasnum=num[col[dy[nowr]]];
            num[col[dy[nowr]]]+=(gx[dy[nowr]]^1)-gx[dy[nowr]];
            gx[dy[nowr]]^=1;
            if (!lasnum && num[col[dy[nowr]]])
                nowans++;
            else if (lasnum && !num[col[dy[nowr]]])
                nowans--;
        }
        while (nowl>o[i].l){
            nowl--;
            int lasnum=num[col[dy[nowl]]];
            num[col[dy[nowl]]]+=(gx[dy[nowl]]^1)-gx[dy[nowl]];
            gx[dy[nowl]]^=1;
            if (!lasnum && num[col[dy[nowl]]])
                nowans++;
            else if (lasnum && !num[col[dy[nowl]]])
                nowans--;
        }
        while (nowr>o[i].r){
            int lasnum=num[col[dy[nowr]]];
            num[col[dy[nowr]]]+=(gx[dy[nowr]]^1)-gx[dy[nowr]];
            gx[dy[nowr]]^=1;
            if (!lasnum && num[col[dy[nowr]]])
                nowans++;
            else if (lasnum && !num[col[dy[nowr]]])
                nowans--;
            nowr--;
        }
        while (nowl<o[i].l){
            int lasnum=num[col[dy[nowl]]];
            num[col[dy[nowl]]]+=(gx[dy[nowl]]^1)-gx[dy[nowl]];
            gx[dy[nowl]]^=1;
            if (!lasnum && num[col[dy[nowl]]])
                nowans++;
            else if (lasnum && !num[col[dy[nowl]]])
                nowans--;
            nowl++;
        }   
        if (o[i].another)
            ans[o[i].time]=nowans+!num[o[i].another];
        else
            ans[o[i].time]=nowans;
    }
    for (int i=1;i<=m;++i)
        printf("%d\n",ans[i]);
    return 0;
}
void init(int x){
    in[x]=++nowdfn;
    dy[nowdfn]=x;
    dep[x]=dep[fa[x][0]]+1;
    for (int i=1;1<<i<dep[x];++i)
        fa[x][i]=fa[fa[x][i-1]][i-1];
    for (EDGE *ei=last[x];ei;ei=ei->las)
        if (ei->to!=fa[x][0])
            fa[ei->to][0]=x,init(ei->to);
    out[x]=++nowdfn;
    dy[nowdfn]=x;
}
int LCA(int u,int v){//这是利用倍增来求LCA的
    if (dep[u]<dep[v])
        swap(u,v);
    for (int k=dep[u]-dep[v],i=0;k;k>>=1,++i)
        u=fa[u][i];
    if (u==v)
        return u;
    for (int i=log2(dep[u]);i>=0;--i)
        if (fa[u][i]!=fa[v][i]){
            u=fa[u][i];
            v=fa[v][i];
        }
    return fa[u][0];
}
posted @ 2018-09-07 21:02  jz_597  阅读(133)  评论(0编辑  收藏  举报