【JZOJ5678】果树

Description

有一棵n个节点的树,每个点有颜色,同种颜色在树上不超过t(t<=20)次,求不出现同种颜色的简单路径条数,(u,v)和(v,u)视为同一条路径,(u,u)也算一条路径。

Solution

考虑什么样的点对(路径)不能选。

设有一个 n×n 的矩形,矩形上 (x,y) 为黑色表示路径 (x,y) 为不合法,若为白色则合法。那么考虑同种颜色的一对点,它会在矩形上把很多个格子染黑,然而最后统计还是 n2 的。

我们发现,对于一个限制点对,它染的格子 (xi,yi) xi 在树的dfs序上肯定是连续的一段或两段, yi 也是,所以我们修改矩形的定义,改为dfs序为x到dfs序为y的路径的合法情况。于是每个限制相当于在矩形上覆盖几个小矩形。

用扫描线统计,线段树标记永久化即可。

分析一下时间复杂度,每种颜色出现t次,枚举每对就是 t2 的复杂度,此时有 n/t 种颜色,粗略计算就是 O(ntlog2n)

Code

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<vector>
#define fo(i,j,k) for(int i=j;i<=k;++i)
#define fd(i,j,k) for(int i=j;i>=k;--i)
#define rep(i,x) for(int i=ls[x];i;i=nx[i])
using namespace std;
typedef long long ll;
const int N=2e5+10,M=2e5+10;
int to[M],nx[M],ls[N],num=0;
void link(int u,int v){
    to[++num]=v,nx[num]=ls[u],ls[u]=num;
}
struct node{
    int x,l,r,v;
}a[N*40];
bool cmp(node x,node y){
    return x.x<y.x;
}
vector<int> c[N];
int L[N],R[N],dep[N],tot=0;
int fa[N][20];
int n;
void pre(int x,int fr){
    L[x]=++tot,dep[x]=dep[fr]+1;
    rep(i,x){
        int v=to[i];
        if(v==fr) continue;
        fa[v][0]=x;
        pre(v,x);
    }
    R[x]=tot;
}
void putin(int x1,int x2,int y1,int y2){
    if(x1>x2 || y1>y2) return;
    a[++tot].x=x1,a[tot].l=y1,a[tot].r=y2,a[tot].v=1;
    a[++tot].x=x2+1,a[tot].l=y1,a[tot].r=y2,a[tot].v=-1;
}
int get(int x,int t){
    fd(i,17,0) if(dep[fa[x][i]]>dep[t]) x=fa[x][i];
    return x;
}
void put(int x,int y){
    if(L[x]<=L[y] && R[x]>=R[y]){
        int p=get(y,x);
        putin(1,L[p]-1,L[y],R[y]);
        putin(R[p]+1,n,L[y],R[y]);
    }
    else if(L[y]<=L[x] && R[y]>=R[x]){
        int p=get(x,y);
        putin(L[x],R[x],1,L[p]-1);
        putin(L[x],R[x],R[p]+1,n);
    }
    else putin(L[x],R[x],L[y],R[y]);
}
int tr[N<<2],lz[N<<2];
int max(int x,int y){
    return x>y?x:y;
}
void add(int v,int l,int r,int x,int y,int t){
    if(l==x && r==y) {
        lz[v]+=t,tr[v]=lz[v]>0?r-l+1:tr[v<<1]+tr[(v<<1)+1];
        return;
    }
    int mid=(l+r)>>1;
    if(y<=mid) add(v<<1,l,mid,x,y,t);
    else if(x>mid) add((v<<1)+1,mid+1,r,x,y,t);
    else add(v<<1,l,mid,x,mid,t),add((v<<1)+1,mid+1,r,mid+1,y,t);
    tr[v]=lz[v]?r-l+1:tr[v<<1]+tr[(v<<1)+1];
}
int main()
{
    freopen("tree.in","r",stdin);
    freopen("tree.out","w",stdout);
    scanf("%d",&n);
    fo(i,1,n){
        int x;
        scanf("%d",&x);
        c[x].push_back(i);
    }
    fo(i,2,n){
        int u,v;
        scanf("%d %d",&u,&v);
        link(u,v),link(v,u);
    }
    pre(1,0);
    fo(j,1,17)
    fo(i,1,n) fa[i][j]=fa[fa[i][j-1]][j-1];
    tot=0;
    fo(i,1,n){
        int o=c[i].size();
        if(o<2) continue;
        fo(j,1,o-1)
        fo(k,0,j-1) put(c[i][k],c[i][j]),put(c[i][j],c[i][k]);
    }
    sort(a+1,a+tot+1,cmp);
    ll ans=0;
    int p=0;
    fo(i,1,n+1){
        while(p<tot && a[p+1].x==i){
            p++;
            add(1,1,n,a[p].l,a[p].r,a[p].v);
        }
        ans+=tr[1];
    }
    ans=(ll)n*n-ans;
    printf("%lld",(ans-n)/2+n);
}
posted @ 2018-04-23 21:50  sadstone  阅读(89)  评论(0编辑  收藏  举报