【JZOJ100019】A

Description

这里写图片描述

Solution

结论1: ni=1ni=n ln n

结论2:给一棵树前序遍历,一颗子树的在dfs序上是连续的。

然后我们枚举起点,枚举不能通过的点,那么有些点对是不合法的。

有以下两种情况:
这里写图片描述
第一种(如左图)是两个点互不包含:a节点所在dfs序上所包含的区间为 [la,ra] ,k*a节点为 [lka,rka] ,那么这两个区间的点对是不能互相到达的,即他们不互相到达的两个区间构成了一个左下角为 (la,lka) ,右上角为 [ra,rka] 矩形。

第二种(如右图)是包含的情况,那么k*a点为一个区间 [lka,rka] ,不能通过的点不是一个区间。
我们可以发现,除去b点(在a,b路径上且是a的儿子)的子树,那么剩余的就是不能通过的,也就是两个区间: [1,lb) (rb,n]

于是构造出所有矩形,求矩形覆盖的面积就是不合法路径的方案,用总数减去即可。
求矩形覆盖面积可以用扫描线+线段树。

Solution

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define fo(i,j,k) for(int i=j;i<=k;i++)
#define fd(i,j,k) for(int i=j;i>=k;i--)
#define rep(i,x) for(int i=ls[x];i;i=nx[i])
#define N 100010
#define M 200010
#define ll long long
using namespace std;
int to[M],nx[M],ls[N],num=0;
void link(int x,int y){
    to[++num]=y,nx[num]=ls[x],ls[x]=num;
}
int dep[N],fa[N][18],L[N],R[N],dfn=0;
int lg[N];
void pre(int x)
{
    L[x]=++dfn;
    rep(i,x)
    {
        int v=to[i];
        if(v==fa[x][0]) continue;
        fa[v][0]=x;
        dep[v]=dep[x]+1;
        pre(v);
    }
    R[x]=dfn;
}
int jump(int u,int v){
    if(dep[u]<dep[v]) swap(u,v);
    fd(i,lg[dep[u]-dep[v]+1],0)
    if(dep[fa[u][i]]>dep[v]) u=fa[u][i];
    return u;
}
struct node{
    int x,l,r;
    int w;
}b[N*50];
struct tree{
    int x,lz;
}tr[N*8];
int tot=0;
void putin(int x,int y,int p,int q){
    if(x>p) swap(x,p);
    if(y>q) swap(y,q);
    if(p<y) swap(x,y),swap(p,q);
    b[++tot].x=x,b[tot].l=y,b[tot].r=q,b[tot].w=1;
    b[++tot].x=p+1,b[tot].l=y,b[tot].r=q,b[tot].w=-1;
}
bool cmp(node x,node y){
    return x.x<y.x;
}
void add(int v,int l,int r,int x,int y,int p)
{
    if(l==x && r==y)
    {
        tr[v].lz+=p;
        if(tr[v].lz) tr[v].x=r-l+1;
        else tr[v].x=tr[v*2].x+tr[v*2+1].x;
        return;
    }
    int mid=(l+r)/2;
    if(y<=mid) add(v*2,l,mid,x,y,p);
    else if(x>mid) add(v*2+1,mid+1,r,x,y,p);
    else add(v*2,l,mid,x,mid,p),add(v*2+1,mid+1,r,mid+1,y,p);
    if(tr[v].lz) tr[v].x=r-l+1;
    else tr[v].x=tr[v*2].x+tr[v*2+1].x;
}
int main()
{
    freopen("a.in","r",stdin);
    freopen("a.out","w",stdout);
    int n;
    scanf("%d",&n);
    fo(i,2,n)
    {
        int x,y;
        scanf("%d %d",&x,&y);
        link(x,y),link(y,x);
        lg[i]=log2(i);
    }
    dep[1]=1;
    pre(1);
    fo(j,1,lg[n])
    fo(i,1,n) fa[i][j]=fa[fa[i][j-1]][j-1];
    fo(i,1,n)
    fo(j,2,n/i)
    {
        int p=i,q=i*j;
        if(L[p]>L[q]) swap(p,q);
        if(L[p]<=L[q] && R[q]<=R[p])
        {
            int qson=jump(p,q);
            int p1=p,q1=q;
            if(L[qson]>1) putin(L[q],1,R[q],L[qson]-1);
            if(R[qson]<n) putin(L[q],R[qson]+1,R[q],n);
        }
        else putin(L[p],L[q],R[p],R[q]);
    }
    sort(b+1,b+tot+1,cmp);
    ll ans=n*1ll*(n-1)/2;
    int c=0;
    fo(i,1,n)
    {
        while(b[c+1].x==i && c<tot) c++,add(1,1,n,b[c].l,b[c].r,b[c].w);
        ans-=tr[1].x;
    }
    printf("%lld",ans);
}
posted @ 2017-06-29 16:08  sadstone  阅读(75)  评论(0编辑  收藏  举报