SP10707 COT2 - Count on a tree II 树链不同的元素个数 树上莫队

题目

SP10707 COT2 - Count on a tree II (https://www.luogu.com.cn/problem/SP10707)

题意

给定 n 个结点的树,每个结点有一种颜色。
m 次询问,每次询问给出 u,v 回答 u,v 之间的路径上的结点的不同颜色数。
1≤n≤4×10^4
1≤m≤10^5

输入

8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5
7 8

输出

4
4

思路

我们考虑在序列上求区间不同元素的个数,用莫队就可以解决了。
那么在树链上呢?当然也可以,用dfs序就可以把树压平成一个序列。

裸的树上莫队。

其实和普通莫队上一样的,只不过我们要把树转化为线性结构,这就需要欧拉序,我们从根对这棵树进行dfsdfs,点进栈时记一个时间戳stst,出栈时再记一个时间戳eded,画个图理解一下

这棵树的欧拉序为(1,2,4,5,5,6,6,7,7,4,2,3,3)(1,2,4,5,5,6,6,7,7,4,2,3,3),那么每次询问的节点u,vu,v有两种情况

u在v的子树中(vv在uu的子树中同理),比如u=6,v=2我们拿出(st[2],st[6])这段区间(2,4,5,5,6) 55出现了两次,因为搜索的时候55不属于这条链,所以进去之后就出去了,而出现一次的都在这条链上,就都可以统计

u和v不在同一个子树中,比如u=5,v=3,这次拿出(ed[5],st[3])这段区间(5,6,6,7,7,4,2,3),要保证st[u]<st[v],出现两次的可以忽略,然而这次只统计了5,4,2,3所以最后再统计上lca就好了

至于如何忽略掉区间内出现了两次的点,这个很简单,我们多记录一个use[x],表示x这个点有没有被加入,每次处理的时候如果use[x]=0则需要添加节点;如果use[x]=1则需要删除节点,每次处理之后都对use[x]异或1就可以了

而欧拉序可以用树剖来求,lca也就求出来了,非常的方便

排序的话没有区别,可以普通排序,也可以奇偶性排序

因为st,ed的大小都是n,所以取块的大小时要用2n,而不是n

最后要注意的一点就是这个题权值比较大,需要离散化

#pragma GCC optimize(3, "Ofast", "inline")
//#pragma GCC target("avx,avx2,fma")
#pragma GCC optimization ("unroll-loops")

#include <bits/stdc++.h>
using namespace std;

#define LL long long
#define rint register int

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<20,stdin),p1==p2)?EOF:*p1++)
char buf[1<<20],*p1=buf,*p2=buf;
inline int read() {
    int f=0,fu=1;
    char c=getchar();
    while(c<'0'||c>'9') {
        if(c=='-')
            fu=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9') {
        f=(f<<3)+(f<<1)+c-48;
        c=getchar();
    }
    return f*fu;
}

inline void print(LL x) {
  if (x < 0) { putchar('-'); x = -x; }
  if (x >= 10) print(x / 10);
  putchar(x % 10 + '0');
}
using namespace std;

struct Tree_lca {
    struct edge {
        int t, nex;
    } e[500010 << 1];
    int head[500010], tot;
    int depth[200050], fa[200050][22], lg[200050];
    int in[200050], out[200050], T;
    int id[200050];//记录dfs序对应的节点
    void init(int n) {
        memset(head, 0, sizeof(head));
        lg[1]=1;
        for(int i=2; i<=n; i++) {
            lg[i]=lg[i>>1]+1;
        }
        tot=T=0;
    }

    void add(int x, int y) {
        e[++tot].t = y;
        e[tot].nex = head[x];
        head[x] = tot;
    }

    void dfs(int now, int fath) {
        fa[now][0] = fath;
        in[now]=++T;
        id[T]=now;
        depth[now] = depth[fath] + 1;
        for(int i = 1; i <= lg[depth[now]]; ++i)
            fa[now][i] = fa[fa[now][i-1]][i-1];
        for(int i = head[now]; i; i = e[i].nex)
            if(e[i].t != fath)
                dfs(e[i].t, now);

        out[now]=++T;
        id[T]=now;
    }
    int LCA(int x, int y) {
        if(depth[x] < depth[y])
            swap(x, y);
        while(depth[x] > depth[y])
            x = fa[x][lg[depth[x]-depth[y]] - 1];
        if(x == y)
            return x;
        for(int k = lg[depth[x]] - 1; k >= 0; --k)
            if(fa[x][k] != fa[y][k])
                x = fa[x][k], y = fa[y][k];
        return fa[x][0];
    }

} T;

//下标1开始
struct LSH { //离散化
    int b[200050];
    int lsh(int a[], int n) { //得到离散化后不同元素的个数
        for(int i=1; i<=n; i++)
            b[i]=a[i];
        sort(b+1, b+n+1);
        int cnt=unique(b+1, b+n+1)-b-1;
        for(int i=1; i<=n; i++) {
            a[i]=lower_bound(b+1, b+cnt+1, a[i])-b;
        }
        return cnt;
    }
} Lsh;

struct que {
    int l, r, k, lca, i;
} q[200050];

int a[200050];
struct Lx_md {

    int ans[200050];
    int c[200050], ANS=0;
    int vis[200050];

    void del(int x) {
        ANS-=(--c[x]==0);
    }
    void calc(int x){
        x=T.id[x];

        if(!vis[x]){//判断这个节点是否这个区间,第二次出现是删除
            add(a[x]);
        }
        else{
            del(a[x]);
        }
        vis[x]^=1;
    }
    void add(int x) {
        ANS+=(++c[x]==1);
    }

    void getans(int Q) {
        sort(q+1, q+Q+1, [](que &a, que &b) {
            return a.k==b.k?(a.k&1)?a.r<b.r:a.r>b.r:a.k<b.k;
        });
        int L=1, R=0;
        for(int i=1; i<=Q; i++) {
            while(L<q[i].l) {
                calc(L);
                L++;
            }
            while(L>q[i].l) {
                L--;
                calc(L);
            }
            while(R<q[i].r) {
                R++;
                calc(R);
            }
            while(R>q[i].r) {
                calc(R);
                R--;
            }
            if(q[i].lca) calc(q[i].lca);
            ans[q[i].i]=ANS;
            if(q[i].lca) calc(q[i].lca);
        }
    }
} md;


int main() {

    int n=read(), m=read();
    for(int i=1; i<=n; i++) {
        a[i]=read();
    }
    Lsh.lsh(a, n);

    T.init(n);
    for(int i=1; i<n; i++) {
        int x, y;
        x=read(), y=read();
        T.add(x, y);
        T.add(y, x);
    }
    T.dfs(1, 0);

    int len=sqrt(2*n);//块大小
    for(int i=1; i<=m; i++) {
        int x, y;
        x=read(); y=read();

        int lca=T.LCA(x, y);
        if(T.in[x]>T.in[y]){
            swap(x, y);
        }
        
        //查询区间转化成dfs序
        if(lca==x) {
            q[i].l=T.in[x];
            q[i].r=T.in[y];
            q[i].k=q[i].l/len;
            q[i].lca=0;
            q[i].i=i;
        } else {
            q[i].l=T.out[x];
            q[i].r=T.in[y];
            q[i].k=q[i].l/len;
            q[i].lca=T.in[lca];
            q[i].i=i;
        }
    }
    md.getans(m);
    for(int i=1; i<=m; i++) {
        printf("%d\n", md.ans[i]);
    }


    return 0;
}
posted @   liweihang  阅读(146)  评论(0编辑  收藏  举报
编辑推荐:
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
Live2D
欢迎阅读『SP10707 COT2』
点击右上角即可分享
微信分享提示