Luogu6177 Count on a tree II/【模板】树分块

https://www.luogu.com.cn/problem/P6177

树分块

思路就是每隔\(S\)个点,取一个关键点,记录每两个关键点之间的信息,用\(bitset\)维护

然后每次询问一条链时,拆成以\(lca\)为上端节点的两条链

对于每条链,关键点之间的信息直接取就行,剩下的单独取(\(bitset\) \(or\) 运算)

如何取关键点?我们从深度最大的非关键点开始枚举,若其\(1-S\)级祖先中没有关键点,那么钦定其\(S\)级祖先为关键点

我用了树剖来计算\(1-S\)级祖先中是否有关键点

然后,就\(T\)了。。。

首先,调整块大小,在不\(MLE\)的情况下尽量接近\(\sqrt n\)

然后我在取关键点时,对于同一深度的点,取关键点顺序胡乱\(random\_shuffle\)一下(不\(rand\)\(T\)了,\(rand\)了也需要看脸,当然,没有开O2

最终\(S\)取了\(300\),不开\(O2\)最大点\(1.93s\)卡了过去(开\(O2\)最大点\(1.18s\)

\(Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<ctime>
#include<bitset>
#define N 40005
#define S 300
#define ls(x) (x << 1)
#define rs(x) ((x << 1) | 1)
#define lc ls(p),l,mid
#define rc rs(p),mid+1,r
#define bt bitset<N>
using namespace std;
int n,m,x,y,tot,mxdep,q[N],a[N],g[N],dep[N],f[N][22],col[N],w[N],fr[N],d[N << 1],nxt[N << 1];
int cnt,lst,rk[N],ht[N],z[N],st[N],ss[N],dfn[N],id[N],sz[N],son[N],T[N],fg[N];
bool tr[N << 2];
bool ke[N];
bt k[135],ans;
bt s[135][135];
inline int read()
{
    int s=0;
    char c=getchar();
    while (c<'0' || c>'9')
        c=getchar();
    while ('0'<=c && c<='9')
    {
        s=s*10+c-'0';
        c=getchar();
    }
    return s;
}
inline void add(int x,int y)
{
    tot++;
    d[tot]=y;
    nxt[tot]=fr[x];
    fr[x]=tot;
}
void dfs(int u)
{
    int mx=-1;
    sz[u]=1;
    mxdep=max(mxdep,dep[u]);
    g[dep[u]]++;
    for (int i=fr[u];i;i=nxt[i])
    {
        int v=d[i];
        if (v==f[u][0])
            continue;
        f[v][0]=u;
        dep[v]=dep[u]+1;
        dfs(v);
        sz[u]+=sz[v];
        if (sz[v]>mx)
        {
            mx=sz[v];
            son[u]=v;
        }
    }
}
void dfs2(int u,int tp)
{
    cnt++;
    dfn[cnt]=u;
    id[u]=cnt;
    T[u]=tp;
    if (!son[u])
        return;
    dfs2(son[u],tp);
    for (int i=fr[u];i;i=nxt[i])
    {
        int v=d[i];
        if (v==f[u][0] || v==son[u])
            continue;
        dfs2(v,v);
    }
}
inline int lca(int x,int y)
{
    if (dep[x]<dep[y])
        swap(x,y);
    for (int i=20;i>=0;i--)
        if (dep[f[x][i]]>=dep[y])
            x=f[x][i];
    if (x==y)
        return x;
    for (int i=20;i>=0;i--)
        if (f[x][i]!=f[y][i])
        {
            x=f[x][i];
            y=f[y][i];
        }
    return f[x][0];
}
void update(int p)
{
    if (!p)
        return;
    tr[p]=tr[ls(p)] | tr[rs(p)];
    update(p >> 1);
}
void build(int p,int l,int r)
{
    if (l==r)
    {
        ss[l]=p;
        return;
    }
    int mid=(l+r) >> 1;
    build(lc);
    build(rc);
    tr[p]=tr[ls(p)] | tr[rs(p)];
}
bool calc(int p,int l,int r,int x,int y)
{
    if (l==x && r==y)
        return tr[p];
    int mid=(l+r) >> 1;
    if (y<=mid)
        return calc(lc,x,y); else
    if (x>mid)
        return calc(rc,x,y); else
        {
            if (calc(lc,x,mid))
                return true;
            if (calc(rc,mid+1,y))
                return true;
            return false;
        }
}
bool check(int x,int y)
{
    while (T[x]!=T[y])
    {
        if (calc(1,1,n,id[T[x]],id[x]))
            return true;
        x=f[T[x]][0];
    }
    if (calc(1,1,n,id[y],id[x]))
        return true;
    return false;
}
inline void fl(int x,int y)
{
    while (!ke[x])
    {
        ans.set(col[x]);
        if (x==y)
            return;
        x=f[x][0];
    }
    int rx=x,px=-1;
    while (dep[ht[x]]>=dep[y])
        px=x,x=ht[x];
    if (~px)
        ans|=s[rk[rx]][rk[px]];
    ans.set(col[x]);
    while (x!=y)
    {
        x=f[x][0];
        ans.set(col[x]);
    }
}
void write(int x)
{
    if (x>9)
        write(x/10);
    putchar(x%10+'0');
}
int main()
{
	srand(time(NULL));
    n=read(),m=read();
    for (int i=1;i<=n;i++)
        col[i]=read(),w[i]=col[i];
    sort(w+1,w+n+1);
    int cc=unique(w+1,w+n+1)-w-1;
    for (int i=1;i<=n;i++)
        col[i]=lower_bound(w+1,w+cc+1,col[i])-w;
    for (int i=1;i<n;i++)
    {
        x=read(),y=read();
        add(x,y),add(y,x);
    }
    dep[1]=1;
    dfs(1);
    dfs2(1,1);
    for (int j=1;j<=20;j++)
        for (int i=1;i<=n;i++)
            f[i][j]=f[f[i][j-1]][j-1];
    for (int i=1;i<=mxdep;i++)
        g[i]+=g[i-1],fg[i]=g[i];
    for (int i=1;i<=n;i++)
        a[g[dep[i]]--]=i;
    for (int i=1;i<=mxdep;i++)
    	if (fg[i]>fg[i-1])
			random_shuffle(a+fg[i-1]+2,a+fg[i]+1);
    int p=S;
    for (int j=20;j>=0;j--)
        if (p>=(1 << j))
        {
            p-=(1 << j);
            q[++q[0]]=j;
        }
    build(1,1,n);
    for (int i=n;i;i--)
    {
        int u=a[i],v=u;
        if (ke[u])
            continue;
        if (dep[u]>S)
        {
            for (int j=1;j<=q[0];j++)
                v=f[v][q[j]];
            st[u]=v;
            if (!check(u,v))
            {
                z[++z[0]]=v;
                rk[v]=z[0];
                ke[v]=true;
                tr[ss[id[v]]]=true;
                update(ss[id[v]] >> 1);
            }
        }
    }
    for (int i=1;i<=z[0];i++)
    {
        int u=z[i];
        k[i].set(col[u]);
        u=f[u][0];
        while (!ke[u])
        {
            k[i].set(col[u]);
            u=f[u][0];
            if (!u)
                break;
        }
        ht[z[i]]=u;
    }
    for (int i=1;i<=z[0];i++)
    {
        s[i][i]=k[i];
        int u=z[i];
        while (ht[u])
        {
            s[i][rk[ht[u]]]=s[i][rk[u]] | k[rk[ht[u]]];
            u=ht[u];
        }
    }
    for (int i=1;i<=m;i++)
    {
        if (i!=1)
            ans.reset();
        x=read(),y=read();
        x^=lst;
        int kz=lca(x,y);
        fl(x,kz),fl(y,kz);
        lst=ans.count();
        write(lst),putchar('\n');
    }
    return 0;
}
posted @ 2020-09-05 18:11  GK0328  阅读(207)  评论(0编辑  收藏  举报