【JZOJ100019】A
Description
Solution
结论1: ∑ni=1ni=n ln n
结论2:给一棵树前序遍历,一颗子树的在dfs序上是连续的。
然后我们枚举起点,枚举不能通过的点,那么有些点对是不合法的。
有以下两种情况:
第一种(如左图)是两个点互不包含:a节点所在dfs序上所包含的区间为
[la,ra]
,k*a节点为
[lk∗a,rk∗a]
,那么这两个区间的点对是不能互相到达的,即他们不互相到达的两个区间构成了一个左下角为
(la,lk∗a)
,右上角为
[ra,rk∗a]
矩形。
第二种(如右图)是包含的情况,那么k*a点为一个区间
[lk∗a,rk∗a]
,不能通过的点不是一个区间。
我们可以发现,除去b点(在a,b路径上且是a的儿子)的子树,那么剩余的就是不能通过的,也就是两个区间:
[1,lb)
,
(rb,n]
。
于是构造出所有矩形,求矩形覆盖的面积就是不合法路径的方案,用总数减去即可。
求矩形覆盖面积可以用扫描线+线段树。
Solution
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define fo(i,j,k) for(int i=j;i<=k;i++)
#define fd(i,j,k) for(int i=j;i>=k;i--)
#define rep(i,x) for(int i=ls[x];i;i=nx[i])
#define N 100010
#define M 200010
#define ll long long
using namespace std;
int to[M],nx[M],ls[N],num=0;
void link(int x,int y){
to[++num]=y,nx[num]=ls[x],ls[x]=num;
}
int dep[N],fa[N][18],L[N],R[N],dfn=0;
int lg[N];
void pre(int x)
{
L[x]=++dfn;
rep(i,x)
{
int v=to[i];
if(v==fa[x][0]) continue;
fa[v][0]=x;
dep[v]=dep[x]+1;
pre(v);
}
R[x]=dfn;
}
int jump(int u,int v){
if(dep[u]<dep[v]) swap(u,v);
fd(i,lg[dep[u]-dep[v]+1],0)
if(dep[fa[u][i]]>dep[v]) u=fa[u][i];
return u;
}
struct node{
int x,l,r;
int w;
}b[N*50];
struct tree{
int x,lz;
}tr[N*8];
int tot=0;
void putin(int x,int y,int p,int q){
if(x>p) swap(x,p);
if(y>q) swap(y,q);
if(p<y) swap(x,y),swap(p,q);
b[++tot].x=x,b[tot].l=y,b[tot].r=q,b[tot].w=1;
b[++tot].x=p+1,b[tot].l=y,b[tot].r=q,b[tot].w=-1;
}
bool cmp(node x,node y){
return x.x<y.x;
}
void add(int v,int l,int r,int x,int y,int p)
{
if(l==x && r==y)
{
tr[v].lz+=p;
if(tr[v].lz) tr[v].x=r-l+1;
else tr[v].x=tr[v*2].x+tr[v*2+1].x;
return;
}
int mid=(l+r)/2;
if(y<=mid) add(v*2,l,mid,x,y,p);
else if(x>mid) add(v*2+1,mid+1,r,x,y,p);
else add(v*2,l,mid,x,mid,p),add(v*2+1,mid+1,r,mid+1,y,p);
if(tr[v].lz) tr[v].x=r-l+1;
else tr[v].x=tr[v*2].x+tr[v*2+1].x;
}
int main()
{
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
int n;
scanf("%d",&n);
fo(i,2,n)
{
int x,y;
scanf("%d %d",&x,&y);
link(x,y),link(y,x);
lg[i]=log2(i);
}
dep[1]=1;
pre(1);
fo(j,1,lg[n])
fo(i,1,n) fa[i][j]=fa[fa[i][j-1]][j-1];
fo(i,1,n)
fo(j,2,n/i)
{
int p=i,q=i*j;
if(L[p]>L[q]) swap(p,q);
if(L[p]<=L[q] && R[q]<=R[p])
{
int qson=jump(p,q);
int p1=p,q1=q;
if(L[qson]>1) putin(L[q],1,R[q],L[qson]-1);
if(R[qson]<n) putin(L[q],R[qson]+1,R[q],n);
}
else putin(L[p],L[q],R[p],R[q]);
}
sort(b+1,b+tot+1,cmp);
ll ans=n*1ll*(n-1)/2;
int c=0;
fo(i,1,n)
{
while(b[c+1].x==i && c<tot) c++,add(1,1,n,b[c].l,b[c].r,b[c].w);
ans-=tr[1].x;
}
printf("%lld",ans);
}