【福建集训】果树

题意

  有一棵 \(n\) 个点的树,每个点有一种 \([1,n]\) 内的整数颜色 \(c_i\),求有多少条不含重复颜色的路径(即路径上不能有两点颜色一样。当然,一个点也是一条合法路径)。
  \(n\le 10^5\)

题解

  我 只 会 点 分 治
  考试时先睡了一小时觉,然后再看这题,想了半天发现是个 sb 题,又感觉其他两题都做不来,于是自信码农起来。
  结果 zjr 不到 1h 就写完了这题 200 多行代码还 A 了,我 1h 写完后调了 3h,心态崩了……
  后来才发现这题标解根本不是点分治,大概是扫描线+线段树……写得比我还简单……我菜得没救了,,,
  顺便吐槽一下这种题暴力分好多啊,以后谁还自信当码农

扫描线+线段树

  对于颜色的限制 等价于 \(O(Tn)\)(本题中 \(T\le 20\))组形如 “\(u,v\) 这两点不能同时出现在路径上” 的限制。
  把一条路径 \((u,v)\) 对应到 \(n\times n\) 的二维平面上的点 \((dfn_u,dfn_v)\),则一条限制等价于二维平面上的 \(1\)\(2\) 个障碍矩形。
  最后就是求二维平面上有多少个点不被任何一个障碍矩形覆盖。扫描线+线段树做即可,复杂度 \(O(Tn\log n)\)
  我又忘了这种基础操作了,我在想什么啊,zbl

我的傻逼点分治

  太恶心了,代码都是我糊出来的,不想细讲……
  大概就是开两棵线段树,一棵记录以某个点为根的子树中有多少个点不能作为另一端点,另一棵记录某个点是否能作为另一端点。
  实际上也可能不要线段树?我说了代码是糊的,我自己都不知道咋蒙过的
  复杂度 \(O(Tn\log^2 n)\),理论上被标解吊打,实际上可能跑得比标解快……

#include<bits/stdc++.h>
#define ll long long
#define N 200005
#define inf 2147483647
using namespace std;
inline int read(){
    int x=0; bool f=1; char c=getchar();
    for(;!isdigit(c); c=getchar()) if(c=='-') f=0;
    for(; isdigit(c); c=getchar()) x=(x<<3)+(x<<1)+(c^'0');
    if(f) return x;
    return 0-x;
}
int n,c[N]; ll ans;
struct edge{int v,nxt;}e[N<<1];
int hd[N],cnt;
inline void add(int u, int v){e[++cnt]=(edge){v,hd[u]}, hd[u]=cnt;}
int Siz,rt,rt_mxson,siz[N];
void getRoot(int u, int fa){
    siz[u]=1; int mxson=0;
    for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v && e[i].v!=fa)
        getRoot(e[i].v,u), siz[u]+=siz[e[i].v], mxson=max(mxson,siz[e[i].v]);
    mxson=max(mxson,Siz-siz[u]);
    if(mxson<rt_mxson) rt_mxson=mxson, rt=u;
}
  
struct SegTree{
    #define ls o<<1
    #define rs o<<1|1
    int sum[N<<2],tag[N<<2]; bool clr[N<<2];
    inline void pushup(int o){sum[o]=sum[ls]+sum[rs];}
    inline void pushdown(int o, int l, int r, int mid){
        if(clr[o]){
            sum[ls]=sum[rs]=tag[ls]=tag[rs]=0,
            clr[ls]=clr[rs]=1,
            clr[o]=0;
        }
        if(tag[o]){
            sum[ls]+=(mid-l+1)*tag[o],
            sum[rs]+=(r-mid)*tag[o],
            tag[ls]+=tag[o],
            tag[rs]+=tag[o],
            tag[o]=0;
        }
    }
    void mdf(int o, int l, int r, int L, int R, int v){
        if(L<=l && r<=R){sum[o]+=(r-l+1)*v, tag[o]+=v; return;}
        int mid=l+r>>1;
        pushdown(o,l,r,mid);
        if(L<=mid) mdf(ls,l,mid,L,R,v);
        if(R>mid) mdf(rs,mid+1,r,L,R,v);
        pushup(o);
    }
    int query(int o, int l, int r, int L, int R){
        if(L<=l && r<=R) return sum[o];
        int mid=l+r>>1, ret=0;
        pushdown(o,l,r,mid);
        if(L<=mid) ret+=query(ls,l,mid,L,R);
        if(R>mid) ret+=query(rs,mid+1,r,L,R);
        return ret;
    }
    #undef ls
    #undef rs
}sgt1,sgt2;
  
int suf_sum;
int dfn[N],idx,siz2[N]; vector<int> p[N];
void dfs1(int u, int fa){
    dfn[u]=++idx, siz2[u]=1;
    for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v && e[i].v!=fa) dfs1(e[i].v,u), siz2[u]+=siz2[e[i].v];
}
bool flag[N][21]; int val[N][21],vis[N];
int V;
void dfs2(int u, int fa){
    int x;
    for(int i=0; i<p[c[u]].size(); ++i) if(!sgt2.query(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]])){
        //cout<<u<<' '<<p[c[u]][i]<<endl;
        x = min(siz2[p[c[u]][i]],suf_sum+1) - sgt1.query(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]]+siz2[p[c[u]][i]]-1);
        //cout<<"dfs2:"<<u<<' '<<p[c[u]][i]<<' '<<x<<' '<<dfn[p[c[u]][i]]<<' '<<siz2[p[c[u]][i]]<<' '<<sgt1.query(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]]+siz2[p[c[u]][i]]-1)<<endl;
        sgt1.mdf(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]],x);
        sgt2.mdf(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]]+siz2[p[c[u]][i]]-1,1); //cout<<i<<' '<<p[c[u]][i]<<endl;
        flag[u][i]=1, val[u][i]=x;
    }
    ans-=sgt1.query(1,1,n,1,Siz)-sgt1.query(1,1,n,dfn[V],dfn[V]+siz[V]-1);
    //cout<<u<<' '<<sgt1.query(1,1,n,1,Siz)<<' '<<sgt1.query(1,1,n,dfn[V],dfn[V]+siz[V]-1)<<endl;
    //cout<<"dfs2:"<<u<<' '<<suf_sum+1<<' '<<sgt1.query(1,1,n,1,Siz)<<' '<<sgt1.query(1,1,n,dfn[V],dfn[V]+siz[V]-1)<<' '<<sgt1.query(1,1,n,1,1)<<endl;
    vis[c[u]]=1, ans+=suf_sum+1;
    for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v && e[i].v!=fa && !vis[c[e[i].v]]) dfs2(e[i].v,u);
    vis[c[u]]=0;
    for(int i=0; i<p[c[u]].size(); ++i) if(flag[u][i]){
        sgt1.mdf(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]],-val[u][i]);
        sgt2.mdf(1,1,n,dfn[p[c[u]][i]],dfn[p[c[u]][i]]+siz2[p[c[u]][i]]-1,-1);
        flag[u][i]=0;
    }
}
void dfs3(int u, int fa){
    p[c[u]].push_back(u);
    if(vis[c[u]]==1 && !sgt2.query(1,1,n,dfn[u],dfn[u])) sgt1.mdf(1,1,n,dfn[u],dfn[u],siz2[u]), sgt2.mdf(1,1,n,dfn[u],dfn[u]+siz2[u]-1,1);
    ++vis[c[u]];
    for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v && e[i].v!=fa) dfs3(e[i].v,u);
    --vis[c[u]];
}
void dfs4(int u, int fa){
    p[c[u]].pop_back();
    for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v && e[i].v!=fa) dfs4(e[i].v,u);
}
  
int getAns(int u){
    idx=0, dfs1(u,0); ++ans, suf_sum=0;
    p[c[u]].push_back(u), vis[c[u]]=1;
    for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v){
        V=e[i].v;
        //cout<<V<<endl;
        if(!vis[c[e[i].v]]) dfs2(e[i].v,u);
        dfs3(e[i].v,u);
        //cout<<suf_sum<<' '<<siz2[e[i].v]<<endl;
        suf_sum+=siz2[e[i].v];
        //cout<<"sgt1:"<<sgt1.query(1,1,n,1,Siz)<<endl;
    }
    //cout<<sum<<endl;
    vis[c[u]]=0, dfs4(u,0);
    //cout<<u<<' '<<sum<<' '<<sub_ans<<endl;
    sgt1.sum[1]=sgt1.tag[1]=0, sgt1.clr[1]=1;
    sgt2.sum[1]=sgt2.tag[1]=0, sgt2.clr[1]=1;
}
void solve(int u, int s){
    if(s==1){++ans; return;}
    Siz=s, rt=0, rt_mxson=inf, getRoot(u,u), u=rt;
    //printf("solve:%lld %lld\n",u,s);
    getAns(u);
    for(int i=hd[u]; i; i=e[i].nxt) if(~e[i].v){
        int v=e[i].v;
        e[i].v=e[i&1?i+1:i-1].v=-1;
        solve(v, siz[v]<siz[u]?siz[v]:s-siz[u]);
    }
}
  
signed main(){
    //freopen("b.in","r",stdin);
    //freopen("b2.out","w",stdout);
    n=read();
    for(int i=1; i<=n; ++i) c[i]=read();
    int u,v;
    for(int i=1; i<n; ++i) u=read(), v=read(), add(u,v), add(v,u);
    solve(1,n);
    cout<<ans<<endl;
    return 0;
}

  
天天就会点分治,分到亲马不认 srO

posted @ 2019-09-23 22:28  大本营  阅读(194)  评论(0编辑  收藏  举报