【JZOJ6074】铁路
Description
Solution
首先列车可能会在边中点相交,给每条边上加一个点,变成求点相交的对数。考虑如何不计算重,先固定根,我们统计两条向上走第一次相交的对数,还有一条向上一条向下的对数。
两条向上可以用线段树(启发式)合并求,就是在起点打加当前深度标记,lca处打减标记,自下往上深度相同时算一下即可。
至于向上向下的有些难处理,考虑链剖,在每一条重链上打标记,具体就是对于一条重链把它看成一个坐标系,横坐标是到链顶的距离,纵坐标是一个列车在某个点上的时间。于是下行或上行都可以看作加y=x+c或y=-x+c的线段,求线段的交点个数。需要注意的是,为了不与前面算重,lca要标记为上行。
至于如何求交,把所有的线段顺(逆)时针转45度,扫描线+树状数组即可解决。
Code
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<vector>
#define fo(i,j,k) for(int i=j;i<=k;++i)
#define fd(i,j,k) for(int i=j;i>=k;--i)
#define rep(i,x) for(int i=ls[x];i;i=nx[i])
#define pb push_back
using namespace std;
typedef long long ll;
const int N=2e5+10,M=4e5+10,inf=1e9+10;
int to[M],nx[M],ls[N],num=0;
void link(int u,int v){
to[++num]=v,nx[num]=ls[u],ls[u]=num;
}
int top[N],z[N],son[N],dep[N],fa[N],tot=0,tt=0;
int sz[N],rt[N],cn[N],n;
struct node{
int v,l,r,s;
}tr[N*50];
struct P{
int k,b,l,r;
P(int _k=0,int _b=0,int _l=0,int _r=0) {k=_k,b=_b,l=_l,r=_r;}
};
ll ans=0;
struct PP{
int p,x,y;
PP(int _p=0,int _x=0,int _y=0) {p=_p,x=_x,y=_y;}
}c[N],d[N];
vector<int> tag[N];
vector<P> b[N];
void ins(int x,int t,int &v,int l=1,int r=n){
if(!v) v=++tot;
tr[v].s+=t;
if(l==r) return;
int mid=(l+r)>>1;
x<=mid?ins(x,t,tr[v].l,l,mid):ins(x,t,tr[v].r,mid+1,r);
}
void merge(int &v,int v1,int l=1,int r=n){
if(!v) return void(v=v1);
if(!v1) return;
if(l==r) ans+=(ll)tr[v].s*tr[v1].s;
tr[v].s+=tr[v1].s;
if(l==r) return;
int mid=(l+r)>>1;
merge(tr[v].l,tr[v1].l,l,mid);
merge(tr[v].r,tr[v1].r,mid+1,r);
}
void pre(int x){
dep[x]=dep[fa[x]]+1,sz[x]=1;
rep(i,x){
int v=to[i];
if(v==fa[x]) continue;
fa[v]=x,pre(v),sz[x]+=sz[v];
if(sz[son[x]]<sz[v]) son[x]=v;
}
}
void pre2(int x,int t){
top[x]=t,z[x]=z[t];
if(son[x]) pre2(son[x],t);
rep(i,x){
int v=to[i];
if(v==fa[x] || v==son[x]) continue;
z[v]=++tt,pre2(v,v);
}
}
void pre3(int x){
rep(i,x){
int v=to[i];
if(v==fa[x]) continue;
pre3(v);
merge(rt[x],rt[v]);
}
if(cn[x]) ans+=(ll)cn[x]*(cn[x]-1)/2,ins(dep[x],cn[x],rt[x]);
int o=tag[x].size();
fo(i,0,o-1) ins(tag[x][i],-1,rt[x]);
}
int lca(int u,int v){
for(;top[u]!=top[v];u=fa[top[u]])
if(dep[top[u]]<dep[top[v]]) swap(u,v);
if(dep[u]<dep[v]) swap(u,v);
return v;
}
void fun(int s,int t){
int u=s,v=t;
int lc=lca(u,v);++cn[u],tag[lc].pb(dep[u]);
int ff=top[u];
for(;ff!=top[lc];u=fa[ff],ff=top[u])
b[z[u]].pb(P(-1,dep[s]-dep[ff],0,dep[u]-dep[ff]));
ff=top[v];
for(;ff!=top[lc];v=fa[ff],ff=top[v])
b[z[v]].pb(P(1,dep[s]+dep[ff]-2*dep[lc],0,dep[v]-dep[ff]));
int l=dep[u]-dep[ff],r=dep[v]-dep[ff];
if(u==lc){
b[z[u]].pb(P(-1,dep[s]-dep[u]+l,l,l));
b[z[u]].pb(P(1,dep[s]-dep[u]-l,l+1,r));
}
else b[z[u]].pb(P(-1,dep[s]-dep[u]+l,r,l));
}
bool cmp(PP x,PP y) {return x.x<y.x;}
bool cmp1(PP x,PP y) {return x.p<y.p;}
int cr[N*3];
int mx;
void add(int x,int t) {for(;x<=mx;x+=x&-x) cr[x]+=t;}
int sum(int x){
int tmp=0;
for(;x;x-=x&-x) tmp+=cr[x];
return tmp;
}
void solve(int now){
int o=b[now].size();
int t1=0,t2=0;mx=0;
fo(i,0,o-1){
P x=b[now][i];
if(x.k==1) c[++t1]=PP(x.b,x.b+x.l*2,1),c[++t1]=PP(x.b,x.b+x.r*2+1,-1);
else d[++t2]=PP(x.b,x.b-2*x.r,x.b-2*x.l);
mx=max(mx,x.b);
}
int mn=inf;
sort(c+1,c+t1+1,cmp);
sort(d+1,d+t2+1,cmp1);
fo(i,1,t1) mn=min(mn,c[i].p);
fo(i,1,t2) mn=min(mn,d[i].x);--mn;
fo(i,1,t1) c[i].p-=mn;
fo(i,1,t2) d[i].x-=mn,d[i].y-=mn;mx-=mn;
int p=0;
fo(i,1,t2){
for(;p<t1 && c[p+1].x<=d[i].p;) ++p,add(c[p].p,c[p].y);
ans+=sum(d[i].y)-sum(d[i].x-1);
}
for(;p<t1;) ++p,add(c[p].p,c[p].y);
}
int main()
{
freopen("train.in","r",stdin);
freopen("train.out","w",stdout);
scanf("%d",&n);
int nn=n;
fo(i,2,nn){
int u,v;
scanf("%d %d",&u,&v),++n;
link(u,n),link(n,u);
link(n,v),link(v,n);
}
pre(1);
tot=0,z[1]=tt=1,pre2(1,1);
int m;
scanf("%d",&m);
fo(i,1,m){
int u,v;
scanf("%d %d",&u,&v);
fun(u,v);
}
pre3(1);
fo(i,1,tt) solve(i);
printf("%lld",ans);
}