看到这种树上统计点对个数的题一般是线段树合并,这题也不出意外
先对这棵树进行树剖,对于每次普及语言,在\(x,y\)两点的线段树上的\(x,y\)两位置打\(+1\)标记,在点\(fa[lca(x,y)]\)的线段树上\(x,y\)两位置打\(-2\)标记
线段树中该维护三个东西:
1.dfs序最小的\(lp\)
2.dfs序最大的\(rp\)
3.线段树中所有被打标机的点到根节点路径的并的节点个数\(sum\)
我们进行搜索并从下向上的进行线段树合并,对于每个节点,对答案的贡献就是该节点线段树的根节点的\(sum-dep[lca(lp,rp)]+1-1\)(珂以画图理解一下)
交上去,发现wa了,因为每条路径的贡献被我们算了\(2\)次,所以除以2即可
#include <bits/stdc++.h>
#define N 100005
#define ll long long
#define getchar nc
using namespace std;
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int read()
{
register int x=0,f=1;register char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x*f;
}
inline void write(register ll x)
{
if(!x)putchar('0');if(x<0)x=-x,putchar('-');
static int sta[20];register int tot=0;
while(x)sta[tot++]=x%10,x/=10;
while(tot)putchar(sta[--tot]+48);
}
int n,m;
ll ans;
struct edge{
int to,next;
}e[N<<1];
int head[N],cnt=0;
inline void add(register int u,register int v)
{
e[++cnt]=(edge){v,head[u]};
head[u]=cnt;
}
int size[N],fa[N],son[N],dep[N],top[N],dfn[N],tim,idfn[N];
inline void dfs1(register int x)
{
size[x]=1;
for(register int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[x])
continue;
dep[v]=dep[x]+1;
fa[v]=x;
dfs1(v);
size[x]+=size[v];
if(size[v]>size[son[x]])
son[x]=v;
}
}
inline void dfs2(register int x,register int t)
{
dfn[x]=++tim;
idfn[tim]=x;
top[x]=t;
if(son[x])
dfs2(son[x],t);
for(register int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
if(v!=fa[x]&&v!=son[x])
dfs2(v,v);
}
}
inline int getlca(register int x,register int y)
{
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(dep[fx]<dep[fy])
{
x^=y^=x^=y;
fx^=fy^=fx^=fy;
}
x=fa[fx];
fx=top[x];
}
return dep[x]<dep[y]?x:y;
}
int root[N],tot;
struct node{
int ls,rs,cnt,lp,rp,sum;
}tr[N*60];
inline void pushup(register int x)
{
tr[x].lp=tr[tr[x].ls].lp?tr[tr[x].ls].lp:tr[tr[x].rs].lp;
tr[x].rp=tr[tr[x].rs].rp?tr[tr[x].rs].rp:tr[tr[x].ls].rp;
tr[x].sum=tr[tr[x].ls].sum+tr[tr[x].rs].sum;
if(tr[tr[x].ls].rp&&tr[tr[x].rs].lp)
tr[x].sum-=dep[getlca(tr[tr[x].ls].rp,tr[tr[x].rs].lp)];
}
inline void modify(register int &x,register int l,register int r,register int pos,register int val)
{
if(!x)
x=++tot;
if(l==r)
{
tr[x].cnt+=val;
if(tr[x].cnt)
tr[x].lp=tr[x].rp=pos,tr[x].sum=dep[pos];
else
tr[x].lp=tr[x].rp=tr[x].sum=0;
return;
}
int mid=l+r>>1;
if(dfn[pos]<=mid)
modify(tr[x].ls,l,mid,pos,val);
else
modify(tr[x].rs,mid+1,r,pos,val);
pushup(x);
}
inline void merge(register int &a,register int b,register int l,register int r)
{
if(!a||!b)
{
a=a+b;
return;
}
if(l==r)
{
tr[a].cnt+=tr[b].cnt;
if(tr[a].cnt)
tr[a].lp=tr[a].rp=idfn[l],tr[a].sum=dep[idfn[l]];
else
tr[a].lp=tr[a].rp=tr[a].sum=0;
return;
}
int mid=l+r>>1;
merge(tr[a].ls,tr[b].ls,l,mid);
merge(tr[a].rs,tr[b].rs,mid+1,r);
pushup(a);
}
inline void dfs(register int x)
{
for(register int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[x])
continue;
dfs(v);
merge(root[x],root[v],1,tim);
}
if(tr[root[x]].lp&&tr[root[x]].rp)
ans+=tr[root[x]].sum-dep[getlca(tr[root[x]].lp,tr[root[x]].rp)];
}
int main()
{
n=read(),m=read();
for(register int i=1;i<n;++i)
{
int u=read(),v=read();
add(u,v),add(v,u);
}
dfs1(1);
dfs2(1,0);
for(register int i=1;i<=m;++i)
{
int x=read(),y=read();
int lca=getlca(x,y);
modify(root[x],1,tim,x,1);
modify(root[x],1,tim,y,1);
modify(root[y],1,tim,x,1);
modify(root[y],1,tim,y,1);
if(lca>1)
{
modify(root[fa[lca]],1,tim,x,-2);
modify(root[fa[lca]],1,tim,y,-2);
}
}
dfs(1);
write(ans>>1);
return 0;
}