LOJ6276:果树——题解

https://loj.ac/problem/6276#submit_code

NiroBC 姐姐是个活泼的少女,她十分喜欢爬树,而她家门口正好有一棵果树,正好满足了她爬树的需求。
这颗果树有N 个节点,节点标号1……N。每个节点长着一个果子,第i 个节点上的果子颜色为Ci。
NiroBC 姐姐每天都要爬树,每天都要选择一条有趣的路径(u,v) 来爬。
一条路径被称作有趣的,当且仅当这条路径上的果子的颜色互不相同。
(u,v) 和(v,u) 被视作同一条路径。特殊地,(i,i) 也被视作一条路径,这条路径只含i 一个果子,显然是有趣的。
NiroBC 姐姐想知道这颗树上有多少条有趣的路径。

线段树好题(当然也很毒)。

考虑u和v同色时的不合法路径,分两种情况讨论:

1.u和v有一个不同于二者的lca

显然不合法路径的两端一个在u的子树中,一个在v的子树中。

2.v是u的祖先。

显然不合法路径的两端一个在u的子树中,一个在v的子树的补集(包括v)中。

同时我们用dfs序定义(u,v)为起点为u终点为v的路径,这样我们可以发现不合法路径的集合恰好围成了多个矩阵。

那么就是扫描线求矩形的并了(不会的话可以看POJ1389:Area of Simple Polygons),当然这是不合法路径,你需要取反后把对角线加上/2才行。

(PPPS:对角线的点显然不会属于任何一个矩形,且其统计的方法不能/2,故需要加上。)

#include<cmath>
#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
typedef double dl;
const int N=1e5+5;
const int B=17;
inline int read(){
    int X=0,w=0;char ch=0;
    while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
    while(isdigit(ch))X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
    return w?-X:X;
}
struct path{
    int to,nxt;
}e[N*2];
struct node{
    int x1,x2,y,w;
}edge[N*160];
vector<int>c[N];
int n,cnt,head[N],pos[N],tot,num;
int anc[N][B+4],dep[N],size[N];
ll tr[N*4],lazy[N*4];
inline bool cmp(node a,node b){
    return a.y<b.y;
}
inline void add(int u,int v){
    e[++cnt].to=v;e[cnt].nxt=head[u];head[u]=cnt;
}
inline int LCA(int i,int j){
    if(dep[i]<dep[j])swap(i,j);
    for(int k=B;k>=0;k--)
        if(dep[anc[i][k]]>=dep[j])i=anc[i][k];
    if(i==j)return i;
    for(int k=B;k>=0;k--)
        if(anc[i][k]!=anc[j][k])
            i=anc[i][k],j=anc[j][k];
    return anc[i][0];
}
void dfs(int u){
    pos[u]=++tot;size[u]=1;
    dep[u]=dep[anc[u][0]]+1;
    for(int i=1;i<=B;i++)
        anc[u][i]=anc[anc[u][i-1]][i-1];
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v!=anc[u][0]){
            anc[v][0]=u;
            dfs(v);
            size[u]+=size[v];
        }
    }
}
void ins(int a,int l,int r,int l1,int r1,int w){
    if(r1<l||l1>r||l1>r1)return;
    if(l1<=l&&r<=r1){
        lazy[a]+=w;
        if(lazy[a]>0)tr[a]=r-l+1;
        else if(l==r)tr[a]=0;
        else tr[a]=tr[a*2]+tr[a*2+1];
        return;
    }
    int mid=(l+r)>>1;
    ins(a*2,l,mid,l1,r1,w);ins(a*2+1,mid+1,r,l1,r1,w);
    if(!lazy[a])tr[a]=tr[a*2]+tr[a*2+1];
}
int main(){
    ll n=read();
    for(int i=1;i<=n;i++)c[read()].push_back(i);
    for(int i=1;i<n;i++){
        int u=read(),v=read();
        add(u,v);add(v,u);
    }
    dfs(1);
    for(int i=1;i<=n;i++){
        for(int j=0;j<c[i].size();j++){
            for(int l=j+1;l<c[i].size();l++){
                int u=c[i][j],v=c[i][l];
                int lca=LCA(u,v);
                if(lca!=u&&lca!=v){
                    int x1=pos[u],y1=pos[v];
                    int x2=x1+size[u]-1,y2=y1+size[v]-1;
                    edge[++num]=(node){x1-1,x2,y1,1};
                    edge[++num]=(node){x1-1,x2,y2+1,-1};
                    edge[++num]=(node){y1-1,y2,x1,1};
                    edge[++num]=(node){y1-1,y2,x2+1,-1};
                }else{
                    if(dep[u]>dep[v])swap(u,v);
                    int p=v;
                    for(int k=B;k>=0;k--)
                        if(dep[anc[p][k]]>=dep[u]+1)p=anc[p][k];
                    int x1=pos[p],y1=pos[v];
                    int x2=x1+size[p]-1,y2=y1+size[v]-1;
                    edge[++num]=(node){0,x1-1,y1,1};
                    edge[++num]=(node){0,x1-1,y2+1,-1};
                    edge[++num]=(node){y1-1,y2,1,1};
                    edge[++num]=(node){y1-1,y2,x1,-1};
                    
                    edge[++num]=(node){x2,n,y1,1};
                    edge[++num]=(node){x2,n,y2+1,-1};
                    edge[++num]=(node){y1-1,y2,x2+1,1};
                    edge[++num]=(node){y1-1,y2,n+1,-1};
                }
            }
        }
    }
    ll ans=0;
    sort(edge+1,edge+num+1,cmp);
    ins(1,1,n,edge[1].x1+1,edge[1].x2,edge[1].w);
    for(int i=2;i<=num;i++){
        ll h=edge[i].y-edge[i-1].y;
        ans+=h*tr[1];
        ins(1,1,n,edge[i].x1+1,edge[i].x2,edge[i].w);
    }
    printf("%lld\n",(n*(n+1)-ans)>>1);
    return 0;
}

+++++++++++++++++++++++++++++++++++++++++++

+本文作者:luyouqi233。               +

+欢迎访问我的博客:http://www.cnblogs.com/luyouqi233/+

+++++++++++++++++++++++++++++++++++++++++++

posted @ 2018-04-17 22:54  luyouqi233  阅读(1557)  评论(0编辑  收藏  举报