【JZOJ6074】铁路

Description

在这里插入图片描述

Solution

首先列车可能会在边中点相交,给每条边上加一个点,变成求点相交的对数。考虑如何不计算重,先固定根,我们统计两条向上走第一次相交的对数,还有一条向上一条向下的对数。
两条向上可以用线段树(启发式)合并求,就是在起点打加当前深度标记,lca处打减标记,自下往上深度相同时算一下即可。
至于向上向下的有些难处理,考虑链剖,在每一条重链上打标记,具体就是对于一条重链把它看成一个坐标系,横坐标是到链顶的距离,纵坐标是一个列车在某个点上的时间。于是下行或上行都可以看作加y=x+c或y=-x+c的线段,求线段的交点个数。需要注意的是,为了不与前面算重,lca要标记为上行。
至于如何求交,把所有的线段顺(逆)时针转45度,扫描线+树状数组即可解决。

Code

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<vector>
#define fo(i,j,k) for(int i=j;i<=k;++i)
#define fd(i,j,k) for(int i=j;i>=k;--i)
#define rep(i,x) for(int i=ls[x];i;i=nx[i])
#define pb push_back
using namespace std;
typedef long long ll;
const int N=2e5+10,M=4e5+10,inf=1e9+10;
int to[M],nx[M],ls[N],num=0;
void link(int u,int v){
	to[++num]=v,nx[num]=ls[u],ls[u]=num;
}
int top[N],z[N],son[N],dep[N],fa[N],tot=0,tt=0;
int sz[N],rt[N],cn[N],n;
struct node{
	int v,l,r,s;
}tr[N*50];
struct P{
	int k,b,l,r;
	P(int _k=0,int _b=0,int _l=0,int _r=0) {k=_k,b=_b,l=_l,r=_r;}
};
ll ans=0;
struct PP{
	int p,x,y;
	PP(int _p=0,int _x=0,int _y=0) {p=_p,x=_x,y=_y;}
}c[N],d[N];
vector<int> tag[N];
vector<P> b[N];
void ins(int x,int t,int &v,int l=1,int r=n){
	if(!v) v=++tot;
	tr[v].s+=t;
	if(l==r) return;
	int mid=(l+r)>>1;
	x<=mid?ins(x,t,tr[v].l,l,mid):ins(x,t,tr[v].r,mid+1,r);
}
void merge(int &v,int v1,int l=1,int r=n){
	if(!v) return void(v=v1);
	if(!v1) return;
	if(l==r) ans+=(ll)tr[v].s*tr[v1].s;
	tr[v].s+=tr[v1].s;
	if(l==r) return;
	int mid=(l+r)>>1;
	merge(tr[v].l,tr[v1].l,l,mid);
	merge(tr[v].r,tr[v1].r,mid+1,r);
}
void pre(int x){
	dep[x]=dep[fa[x]]+1,sz[x]=1;
	rep(i,x){
		int v=to[i];
		if(v==fa[x]) continue;
		fa[v]=x,pre(v),sz[x]+=sz[v];
		if(sz[son[x]]<sz[v]) son[x]=v; 
	}
}
void pre2(int x,int t){
	top[x]=t,z[x]=z[t];
	if(son[x]) pre2(son[x],t);
	rep(i,x){
		int v=to[i];
		if(v==fa[x] || v==son[x]) continue;
		z[v]=++tt,pre2(v,v);
	}
}
void pre3(int x){
	rep(i,x){
		int v=to[i];
		if(v==fa[x]) continue;
		pre3(v);
		merge(rt[x],rt[v]);
	}
	if(cn[x]) ans+=(ll)cn[x]*(cn[x]-1)/2,ins(dep[x],cn[x],rt[x]);
	int o=tag[x].size();
	fo(i,0,o-1) ins(tag[x][i],-1,rt[x]);
}
int lca(int u,int v){
	for(;top[u]!=top[v];u=fa[top[u]])
	if(dep[top[u]]<dep[top[v]]) swap(u,v);
	if(dep[u]<dep[v]) swap(u,v);
	return v;
}
void fun(int s,int t){
	int u=s,v=t;
	int lc=lca(u,v);++cn[u],tag[lc].pb(dep[u]);
	int ff=top[u];
	for(;ff!=top[lc];u=fa[ff],ff=top[u])
	b[z[u]].pb(P(-1,dep[s]-dep[ff],0,dep[u]-dep[ff]));
	ff=top[v];
	for(;ff!=top[lc];v=fa[ff],ff=top[v])
	b[z[v]].pb(P(1,dep[s]+dep[ff]-2*dep[lc],0,dep[v]-dep[ff]));
	int l=dep[u]-dep[ff],r=dep[v]-dep[ff];
	if(u==lc){
		b[z[u]].pb(P(-1,dep[s]-dep[u]+l,l,l));
		b[z[u]].pb(P(1,dep[s]-dep[u]-l,l+1,r));
	}
	else b[z[u]].pb(P(-1,dep[s]-dep[u]+l,r,l));
}
bool cmp(PP x,PP y) {return x.x<y.x;}
bool cmp1(PP x,PP y) {return x.p<y.p;}
int cr[N*3];
int mx;
void add(int x,int t) {for(;x<=mx;x+=x&-x) cr[x]+=t;}
int sum(int x){
	int tmp=0;
	for(;x;x-=x&-x) tmp+=cr[x];
	return tmp;
}
void solve(int now){
	int o=b[now].size();
	int t1=0,t2=0;mx=0;
	fo(i,0,o-1){
		P x=b[now][i];
		if(x.k==1) c[++t1]=PP(x.b,x.b+x.l*2,1),c[++t1]=PP(x.b,x.b+x.r*2+1,-1);
		else d[++t2]=PP(x.b,x.b-2*x.r,x.b-2*x.l);
		mx=max(mx,x.b);
	}
	int mn=inf;
	sort(c+1,c+t1+1,cmp);
	sort(d+1,d+t2+1,cmp1);
	fo(i,1,t1) mn=min(mn,c[i].p);
	fo(i,1,t2) mn=min(mn,d[i].x);--mn;
	fo(i,1,t1) c[i].p-=mn;
	fo(i,1,t2) d[i].x-=mn,d[i].y-=mn;mx-=mn;
	int p=0;
	fo(i,1,t2){
		for(;p<t1 && c[p+1].x<=d[i].p;) ++p,add(c[p].p,c[p].y);
		ans+=sum(d[i].y)-sum(d[i].x-1);
	}
	for(;p<t1;) ++p,add(c[p].p,c[p].y);
}
int main()
{
	freopen("train.in","r",stdin);
	freopen("train.out","w",stdout);
	scanf("%d",&n);
	int nn=n;
	fo(i,2,nn){
		int u,v;
		scanf("%d %d",&u,&v),++n;
		link(u,n),link(n,u);
		link(n,v),link(v,n);
	}
	pre(1);
	tot=0,z[1]=tt=1,pre2(1,1);
	int m;
	scanf("%d",&m);
	fo(i,1,m){
		int u,v;
		scanf("%d %d",&u,&v);
		fun(u,v);
	}
	pre3(1);
	fo(i,1,tt) solve(i);
	printf("%lld",ans);
}

posted @ 2019-03-21 14:56  sadstone  阅读(43)  评论(0编辑  收藏  举报