SP10707 COT2 - Count on a tree II 树上莫队

题目可在vj上提交:https://vjudge.net/problem/SPOJ-COT2

 

题意翻译

  • 给定 n 个结点的树,每个结点有一种颜色。
  • m 次询问,每次询问给出 u,v,回答 u,v 之间的路径上的结点的不同颜色数。
  • 1<=n<=4e4, 1<=m<=1e5

输入输出样例

输入 #1
8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5
7 8
输出 #1
4
4

 

题解:

树上莫队需要用到欧拉序

欧拉序就是第一次遇见x结点的时候把它放到数组里面,最后x子节点遍历完之后再把x放入数组

void dfs(int x)
{
    ord[++len]=x;  //第一次遇见的时候放入数组
    first[x]=len;  //first和second数组是记录某个数在欧拉序中第一次出现和第二次出现的位置
    for(int i=head[x];i;i=e[i].next)
    {
        int to=e[i].to;
        if(to==fa[x][0]) continue;
        depth[to]=depth[x]+1;  //记录节点深度,深度从1开始
        fa[to][0]=x;
        for(int j=1;(1<<j)<=depth[to];++j)  //这一部分是为了求lca做的处理
        {
            fa[to][j]=fa[fa[to][j-1]][j-1];
        }
        dfs(to);
    }
    ord[++len]=x;  //子节点遍历完之后再放入数组
    second[x]=len;
}

 

比如你需要找1->7这条路径上的所有点,就可以通过欧拉序区间[first[1],first[7]]这个区间内的值就对应的1->7这个路径

但是有些路径不是,看下图:

对于1->10你会发现,这个区间内包含了很多实际上用不到的数。

其实我们只需要把这个区间内出现两次的数删掉,剩下的就是1->10这个路径上遇到的点。

1 2 4 7 7 4 5 5 2 3 6 8 9 9 10这个序列删除4、7、5、2、9就变成了1,3,6,8,10正好就是原路经

至于为什么,你可以想一想欧拉序是怎么构成的,如果一个数出现了两次,那就证明这个数是1->10路径上的分支

 

但是对于路径2->6,我们使用上面的方法你会发现获得的序列不满足我们的实际需求,正确操作是找到欧拉序的区间[last[2],start[6]]

这一部分区间对应欧拉序为:2,3,6。少一个1,为什么少一个1?

因为1是它们的最近父节点,你找的2,6分别在1的两个分支上,所以欧拉序这个区间内肯定不包含1。那我们只需要加上2和6的最近父节点就可以了

这个找最近父节点可以使用lca,不会的可以看一下:lca讲解 && 例题 HDU - 4547 

我们这里使用lca的第三种方式,使用倍增lca

 

这样分析之后,我们就发现树上莫队就和普通莫队差不多了

 

代码:

#include <map>
#include <set>
#include <list>
#include <queue>
#include <deque>
#include <cmath>
#include <stack>
#include <vector>
#include <bitset>
#include <cstdio>
#include <string>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 2e5+10;
const int INF = 0x3f3f3f3f;
const double PI = 3.1415926;
const long long N = 1000006;
const double eps = 1e-10;
typedef long long ll;
#define mem(A, B) memset(A, B, sizeof(A))
#define lson rt<<1 , L, mid
#define rson rt<<1|1 , mid + 1, R
#define ls rt<<1
#define rs rt<<1|1
#define SIS std::ios::sync_with_stdiget_mod_new(z-x)o(false), cin.tie(0), cout.tie(0)
#define pll pair<long long, long long>
#define lowbit(abcd) (abcd & (-abcd))
#define max(a, b) ((a > b) ? (a) : (b))
#define min(a, b) ((a < b) ? (a) : (b))
inline int read() {  //读取整数
    int res = 0;
    char c = getchar();
    while(!isdigit(c)) c = getchar();
    while(isdigit(c)) res = (res << 1) + (res << 3) + (c ^ 48), c = getchar();
    return res;
}
int arr[maxn],cnt[maxn],first[maxn],second[maxn],ans[maxn],belong[maxn];
int cnte,inp[maxn],vis[maxn],sizes,new_size,len,now,n,m; //莫队相关
int ord[maxn],val[maxn],head[maxn],depth[maxn],fa[maxn][30];
//ord保存的是欧拉序
struct edge{
    int to,next;
}e[maxn];
struct Node{
    int l,r,lca,id;
}node[maxn];
bool cmp(Node a,Node b)
{
    return (belong[a.l]^belong[b.l])?(belong[a.l]<belong[b.l]):((belong[a.l]&1)?a.r<b.r:a.r>b.r);
}
void add_edge(int x,int y)
{
    e[++cnte]=(edge){y,head[x]};
    head[x]=cnte;
    e[++cnte]=(edge){x,head[y]};
    head[y]=cnte;
}
void dfs(int x)
{
    ord[++len]=x;  //第一次遇见的时候放入数组
    first[x]=len;  //first和second数组是记录某个数在欧拉序中第一次出现和第二次出现的位置
    for(int i=head[x];i;i=e[i].next)
    {
        int to=e[i].to;
        if(to==fa[x][0]) continue;
        depth[to]=depth[x]+1;  //记录节点深度,深度从1开始
        fa[to][0]=x;
        for(int j=1;(1<<j)<=depth[to];++j)  //这一部分是为了求lca做的处理
        {
            fa[to][j]=fa[fa[to][j-1]][j-1];
        }
        dfs(to);
    }
    ord[++len]=x;  //子节点遍历完之后再放入数组
    second[x]=len;
}
int get_lca(int u,int v) //使用倍增lca
{
    if(depth[u] < depth[v])
        swap(u, v);
    for(int i = 20; i + 1; --i)
        if(depth[u] - (1 << i) >= depth[v])
            u = fa[u][i];
    if(u == v)
        return u;
    for(int i = 20; i + 1; --i)
        if(fa[u][i] != fa[v][i])
            u = fa[u][i], v = fa[v][i];
    return fa[u][0];
}
void work(int pos)  //因为欧拉序中一个数出现两次就要删除,所以使用vis数组来标记下
{
    vis[pos] ? now-=!--cnt[val[pos]] : now += !cnt[val[pos]]++;
    vis[pos] ^= 1;
}
int main()
{
    //scanf("%d%d",&n,&m);
    n=read();
    m=read();
    for(int i=1;i<=n;++i)
    {
        //scanf("%d",&val[i]);
        inp[i]=val[i]=read();
    }
    sort(inp+1,inp+1+n);
    int tot=unique(inp+1,inp+1+n)-inp-1;  //去重后有多少元素
    for(int i=1;i<=n;++i)
    {
        //对去重后的数组进行二分
        val[i]=lower_bound(inp+1,inp+1+tot,val[i])-inp;
        //printf("%d\n",val[i]);
    }
    for(int i=1;i<n;++i)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add_edge(x,y);
    }
    depth[1]=1;
    dfs(1);
    sizes=sqrt(len);
    new_size=ceil((double)len/sizes);
    for(int i=1;i<=new_size;++i)
    {
        for(int j=(i-1)*sizes+1;j<=i*sizes;++j)
        {
            belong[j]=i;
        }
    }
    for(int i=1;i<=m;++i)
    {
        int x,y,z;
        x=read();
        y=read();
        //scanf("%d%d",&x,&y);
        z=get_lca(x,y);
        if(first[x]>first[y]) swap(x,y);
        if(x==z)  //如果其中一个节点是它俩的最近父节点,那就采用第一种方法
        {
            node[i].l=first[x];
            node[i].r=first[y];
        } 
        else  //否则就要最后加一个父节点
        {
            node[i].l=second[x];
            node[i].r=first[y];
            node[i].lca=z;
        }
        node[i].id=i;
    }
    sort(node+1,node+1+m,cmp);
    int l=1,r=0;
    for(int i=1;i<=m;++i)
    {
        int start=node[i].l,last=node[i].r,lca=node[i].lca;
        while(l<start) work(ord[l++]);
        while(l>start) work(ord[--l]);
        while(r>last) work(ord[r--]);
        while(r<last) work(ord[++r]);
        if(lca) work(lca);
        ans[node[i].id]=now;
        if(lca) work(lca);
    }
    for(int i=1;i<=m;++i)
        printf("%d\n",ans[i]);
    return 0;
}

 

posted @ 2020-09-24 11:19  kongbursi  阅读(119)  评论(0编辑  收藏  举报