【树】【积累】【题解】Count on a Tree

2020 Multi-University Training Contest 2 Count on a Tree II Striking Back

题意:

一棵含 \(n\) 个点的树,树上每个点都有颜色,m个操作,有两种:

\(1\,x\,y\) :将节点x的颜色设置为 \(y\)

\(2\,a\,b\,c\,d\) :定义 \(f(u,v)\) 表示链 \(u,v\) 上包含不同颜色的个数,询问是否 \(f(a,b)>f(c,d)\)

保证询问时有 \(f(a,b)\ge2f(c,d)\)\(f(c,d)\ge2f(a,b)\)

其中 \(1\le n\le500000,1\le m\le10000,1\le col_i\le n,1\le T\le4\)

官方题解:

有结论:\(k\)\([1,0]\) 的随机实数的最小值的期望为 \(\frac{1}{k+1}\) 。对于一个大小为 \(k\) 的集合,如果给每个元素随机一个正整数,那么多次采样得到的平均最小值越小就说明k的值越大

回到本题:进行 \(k\) 次采样,每次采样时对每种颜色随机一个正整数,令每个点的点权为其颜色对应的随机数,然后统计询问的树链上点权的最小值,将 \(k\) 次采样的结果相加以粗略比较两条树链的颜色数的大小,因为不要求精确值所以 \(k\) 取几十即可得到正确结果。

使用树链剖分+线段树的时间复杂度为 \(\mathcal{O(nk+mklog^2n)}\) ,使用全局平衡二叉树可以左到 \(\mathcal{O(nk+mklogn)}\)

个人代码:

树剖:1653MS

#pragma GCC optimize(2)
#include<algorithm>
#include<cstdio>
#include<ctime>
using namespace std;
typedef long long ll;
typedef unsigned int U;
const int mod=998244353;
const int inf=0x3f3f3f3f;
const int maxn=500005;
const int N=30;

char buf[1<<20],*P1=buf,*P2=buf;
#define gc() (P1==P2&&(P2=(P1=buf)+fread(buf,1,1<<20,stdin),P1==P2)?EOF:*P1++)
#define TT template<class T>inline
TT void read(T&x){
    x=0;register char c=gc();register bool f=0;
    while(c<48||c>57){f^=c=='-',c=gc();}
    while(47<c&&c<58)x=(x<<3)+(x<<1)+(c^48),c=gc();
    if(f)x=-x;
}

int i,j;
int head[maxn],to[maxn<<1],nxt[maxn<<1],ect;
inline void addedge(int u,int v){
    to[++ect]=v;nxt[ect]=head[u];head[u]=ect;
}
int siz[maxn],dep[maxn],fad[maxn],son[maxn];
int top[maxn],tid[maxn],rnk[maxn],cnt;
void dfs1(int u,int fa)
{
    dep[u]=dep[fa]+1,fad[u]=fa,siz[u]=1;
    for(int i=head[u];i;i=nxt[i])
        if(to[i]!=fa)
        {
            dfs1(to[i],u);
            siz[u]+=siz[to[i]];
            if(siz[to[i]]>siz[son[u]])son[u]=to[i];
        }
}
void dfs2(int u,int t)
{
    top[u]=t;rnk[tid[u]=++cnt]=u;
    if(!son[u])return;
    dfs2(son[u],t);
    for(int i=head[u];i;i=nxt[i])
        if(to[i]!=son[u]&&to[i]!=fad[u])
            dfs2(to[i],to[i]);
}
int n,col[maxn];
U rc[maxn][N],mi[maxn<<2][N];
void build(int k,int l,int r)
{
    if(l==r){
        int c=col[rnk[l]];
        for(i=0;i<N;i++)
            mi[k][i]=rc[c][i];
        return;
    }
    int mid=(l+r)>>1;
    build(k<<1,l,mid),build(k<<1|1,mid+1,r);
    for(i=0;i<N;i++)
        mi[k][i]=min(mi[k<<1][i],mi[k<<1|1][i]);
}
void update(int k,int l,int r,int x)
{
    if(l==r){
        int c=col[rnk[l]];
        for(i=0;i<N;i++)
            mi[k][i]=rc[c][i];
        return ;
    }
    int mid=(l+r)>>1;
    if(x<=mid)update(k<<1,l,mid,x);
    else update(k<<1|1,mid+1,r,x);
    for(i=0;i<N;i++)
        mi[k][i]=min(mi[k<<1][i],mi[k<<1|1][i]);
}
U tmp[N];
void query(int k,int l,int r,int L,int R)
{
    if(L<=l&&r<=R){
        for(i=0;i<N;i++)tmp[i]=min(tmp[i],mi[k][i]);
        return;
    }
    int mid=(l+r)>>1;
    if(L<=mid)query(k<<1,l,mid,L,R);
    if(R>mid)query(k<<1|1,mid+1,r,L,R);
}
inline ll trquery(int u,int v)
{
    for(i=0;i<N;i++)tmp[i]=inf;
    while(top[u]!=top[v])
    {
        if(dep[top[u]]<dep[top[v]])swap(u,v);
        query(1,1,n,tid[top[u]],tid[u]);
        u=fad[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    query(1,1,n,tid[u],tid[v]);
    ll sum=0;
    for(i=0;i<N;i++)sum+=tmp[i];
    return sum;
}
mt19937 mt(time(0));
int main()
{
//    freopen("C:\\Users\\admin\\Desktop\\tmp\\1003.in","r",stdin);
//    freopen("C:\\Users\\admin\\Desktop\\tmp\\1003out.txt","w",stdout);
    for(i=0;i<maxn;i++)
        for(j=0;j<N;j++)rc[i][j]=mt();
    int T,m,ans,op,x,y,z,w;ll sum1,sum2;
    read(T);
    while(T--)
    {
        ans=0;
        read(n);read(m);
        for(i=ect=cnt=0;i<=n;i++)head[i]=fad[i]=dep[i]=son[i]=siz[i]=0;
        for(i=1;i<=n;i++)read(col[i]);
        for(i=1;i<n;i++)
        {
            read(x);read(y);
            addedge(x,y);addedge(y,x);
        }
        dfs1(1,0);dfs2(1,1);
        build(1,1,n);
        while(m--)
        {
            read(op);read(x);read(y);
            x^=ans,y^=ans;
            if(op==1){
                col[x]=y;
                update(1,1,n,tid[x]);
            }else{
                read(z);read(w);
                z^=ans,w^=ans;
                sum1=trquery(x,y),sum2=trquery(z,w);
                if(sum1<sum2){
                    puts("Yes");ans++;
                }
                else puts("No");
            }
        }
    }
}

全局平衡二叉树:超时(处理不好 或者 常数过大

//#pragma GCC optimize(2)
#include<algorithm>
#include<cstdio>
#include<ctime>
#define min(a,b) (a)<(b)?(a):(b) 
#define mem(a,b) memset(a,b,sizeof(a))
using namespace std;
typedef long long ll;
const int maxn=5e5+5;

char buf[1<<20],*P1=buf,*P2=buf;
#define gc() (P1==P2&&(P2=(P1=buf)+fread(buf,1,1<<20,stdin),P1==P2)?EOF:*P1++)
#define TT template<class T>inline
TT void read(T&x){
    x=0;register char c=gc();register bool f=0;
    while(c<48||c>57){f^=c=='-',c=gc();}
    while(47<c&&c<58)x=(x<<3)+(x<<1)+(c^48),c=gc();
    if(f)x=-x;
}

int to[maxn<<1],nxt[maxn<<1];
int head[maxn],ecnt;
inline void addedge(int u,int v)
{
    to[++ecnt]=v;nxt[ecnt]=head[u];
    head[u]=ecnt;
}
int depth[maxn<<1],id[maxn],rid[maxn<<1],cnt,st[maxn<<1][25];
void dfs(int u,int f,int d)
{
    id[u]=++cnt;rid[cnt]=u;depth[cnt]=d;
    for(int i=head[u];i;i=nxt[i])
        if(to[i]!=f){
            dfs(to[i],u,d+1);
            rid[++cnt]=u;depth[cnt]=d;
        }
}
int lg[maxn<<1];
void init()
{
    lg[0]=-1;
    for(int i=1;i<=cnt;i++)lg[i]=lg[i>>1]+1;
    for(int i=1;i<=cnt;i++)st[i][0]=i;
    for(int j=1;(1<<j)<=cnt;j++)
        for(int i=1;i+(1<<j)-1<=cnt;i++)
            st[i][j]=depth[st[i][j-1]]<depth[st[i+(1<<j-1)][j-1]]?
                     st[i][j-1]:
                     st[i+(1<<j-1)][j-1];
}
inline int lca(int u,int v)
{
    if(id[u]>id[v])swap(u,v);
    int s=id[u],t=id[v],len=lg[t-s+1];
    return depth[st[s][len]]<depth[st[t-(1<<len)+1][len]]?rid[st[s][len]]:rid[st[t-(1<<len)+1][len]];
}
inline int dis(int u,int v){
    return depth[id[u]]+depth[id[v]]-(depth[id[lca(u,v)]]<<1);
}

//gbbt
int siz[maxn],lsiz[maxn],son[maxn];
void dfs1(int u,int f)//处理son和lsiz 的信息
{
    siz[u]=1;
    for(int i=head[u];i;i=nxt[i])
        if(to[i]!=f){
            dfs1(to[i],u);
            siz[u]+=siz[to[i]];
            if(siz[to[i]]>siz[son[u]])son[u]=to[i];
        }
    lsiz[u]=siz[u]-siz[son[u]];
}
int col[maxn];
int tfa[maxn],ch[maxn][2];
const int N=30;
unsigned int mi[maxn][N],val[maxn][N];
inline void update(int u){
    for(int i=0;i<N;i++){
        mi[u][i]=val[col[u]][i];
        if(ch[u][0])mi[u][i]=min(mi[u][i],mi[ch[u][0]][i]);
        if(ch[u][1])mi[u][i]=min(mi[u][i],mi[ch[u][1]][i]);
    }
}
inline int getson(int x){return (ch[tfa[x]][1]==x)?1:((ch[tfa[x]][0]==x)?0:-1);}
void pushup(int u){ 
    update(u);
    if(getson(u)!=-1)pushup(tfa[u]);
}
int si[maxn],tot;ll vsi[maxn];
int dep[maxn];bool vis[maxn];
int sbuild(int l,int r,int d)
{
    if(l>r)return 0;
    int mid=lower_bound(vsi+l,vsi+r+1,vsi[r]+vsi[l-1]>>1)-vsi;
    int u=si[mid];dep[u]=d;
    if(ch[u][0]=sbuild(l,mid-1,d+1))tfa[ch[u][0]]=u;
    if(ch[u][1]=sbuild(mid+1,r,d+1))tfa[ch[u][1]]=u;
    update(u);
    return u;
}
int build(int u,int d)
{
    int pos;
    for(pos=u,tot=0;pos;pos=son[pos]){
        si[++tot]=pos;vsi[tot]=vsi[tot-1]+lsiz[pos];
        vis[pos]=1;
    }
    int rt=sbuild(1,tot,d);
    for(pos=u;pos;pos=son[pos])
        for(int i=head[pos];i;i=nxt[i])
            if(!vis[to[i]])
                tfa[build(to[i],dep[pos]+1)]=pos;
    return rt;
}
inline bool online(int u,int v,int po){
    return dis(u,po)+dis(v,po)==dis(u,v);
}
unsigned int tmp[N];
ll cal(int u,int v)
{
    int x,pu=u,pv=v;bool juu=1,juv=1,fju;
    for(int i=0;i<N;i++)tmp[i]=min(val[col[u]][i],val[col[v]][i]);
    while(u!=v)
    {
        if(dep[u]<dep[v])swap(u,v),swap(juu,juv);//dep u > dep v
        x=getson(u);
        fju=online(pu,pv,tfa[u]);
        if(!fju){
            if(juu&&ch[u][x]){
                for(int i=0;i<N;i++)
                    tmp[i]=min(tmp[i],mi[ch[u][x]][i]);
            }
        }else{
            if(juu&&x==-1&&ch[u][0]){
                for(int i=0;i<N;i++)
                    tmp[i]=min(tmp[i],mi[ch[u][0]][i]);
            }
            else if(juu&&x!=-1&&ch[u][x^1]){
                x^=1;
                for(int i=0;i<N;i++)
                    tmp[i]=min(tmp[i],mi[ch[u][x]][i]);
            }
            for(int i=0;i<N;i++)
                tmp[i]=min(tmp[i],val[col[tfa[u]]][i]);
        }
        juu=fju;u=tfa[u];
    }
    ll sum=0;
    for(int i=0;i<N;i++)sum+=tmp[i];
    return sum;
}
mt19937 mt(time(0));
int main()
{
//    freopen("C:\\Users\\admin\\Desktop\\tmp\\1003.in","r",stdin);
//    freopen("C:\\Users\\admin\\Desktop\\tmp\\1003out.txt","w",stdout);
    for(int i=0;i<maxn;i++)
        for(int j=0;j<N;j++)val[i][j]=mt();
    int T;
    read(T);
    while(T--)
    {
        int n,m,u,v,rt,lans=0;
        read(n);read(m);
        for(int i=tot=ecnt=cnt=0;i<=n;i++)
            vis[i]=vsi[i]=head[i]=dep[i]=tfa[i]=son[i]=0;
        for(int i=1;i<=n;i++)read(col[i]);
        for(int i=1;i<n;i++)
        {
            read(u);read(v);
            addedge(u,v);addedge(v,u);
        }
        dfs(1,0,0);init();dfs1(1,0);build(1,1);
        int op,x,y;ll sum1,sum2;
        while(m--)
        {
            read(op);read(x);read(y);
            if(op==1){
                x^=lans;y^=lans;
                col[x]=y;pushup(x);
            }else if(op==2){
                read(u);read(v);
                x^=lans;y^=lans;u^=lans;v^=lans;
                sum1=cal(x,y);
                sum2=cal(u,v);
                if(sum1<sum2)puts("Yes"),lans++;
                else puts("No");
            }
        }
    }
}
posted @ 2020-08-07 21:32  草丛怪  阅读(202)  评论(0编辑  收藏  举报