CF1486F Pairs of Paths 总结--zhengjun
需要保持:
- 写代码前先仔细考虑一下细节,分类讨论清楚再开始码。
警告:
- namespace 里面写了个 n,想调用全局 n 的时候没加 2*冒号。
思路大概就是分类讨论然后计数就完事了。
代码
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const int N=3e5+10;
int n,m;
vector<int>to[N];
int fa[N],dep[N],bg[N],ed[N];
namespace T{
int dft,pos[N],siz[N],son[N],top[N];
void dfs1(int u){
dep[u]=dep[fa[u]]+1,siz[u]=1;
for(int v:to[u])if(v^fa[u]){
fa[v]=u,dfs1(v);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int t){
top[u]=t,pos[bg[u]=++dft]=u;
if(son[u])dfs2(son[u],t);
for(int v:to[u])if(v^fa[u]&&v^son[u])dfs2(v,v);
ed[u]=dft;
}
void init(){
dfs1(1),dfs2(1,1);
}
int LCA(int u,int v){
for(;top[u]^top[v];u=fa[top[u]]){
if(dep[top[u]]<dep[top[v]])swap(u,v);
}
return dep[u]<dep[v]?u:v;
}
int kth(int u,int k){
for(;dep[u]-dep[top[u]]<k;u=fa[top[u]])k-=dep[u]-dep[top[u]]+1;
return pos[bg[u]-k];
}
}
using T::LCA;
using T::kth;
int cnt[N];
struct path{
int u,v,t,x,y;
}a[N];
ll ans;
namespace Solve1{
void dfs(int u,int fa=0){
cnt[u]+=cnt[fa];
for(int v:to[u])if(v^fa)dfs(v,u);
}
void calc(){
for(int i=1;i<=n;i++)ans+=cnt[i]*(cnt[i]-1ll)/2;
dfs(1);
for(int i=1;i<=m;i++){
ans+=cnt[a[i].u]+cnt[a[i].v]-cnt[a[i].t]-cnt[fa[a[i].t]];
}
}
}
namespace Solve2{
void dfs(int u,int fa=0){
for(int v:to[u])if(v^fa){
dfs(v,u),cnt[u]+=cnt[v];
}
}
void calc(){
fill(cnt,cnt+1+n,0);
for(int i=1;i<=m;i++){
if(a[i].u^a[i].t){
a[i].x=kth(a[i].u,dep[a[i].u]-dep[a[i].t]-1);
cnt[a[i].u]++,cnt[a[i].x]--;
}
if(a[i].v^a[i].t){
a[i].y=kth(a[i].v,dep[a[i].v]-dep[a[i].t]-1);
cnt[a[i].v]++,cnt[a[i].y]--;
}
}
dfs(1);
}
}
namespace Solve3{
int tot[N];
void calc(){
for(int i=1;i<=m;i++)tot[a[i].t]++;
for(int i=1;i<=m;i++){
if(a[i].u^a[i].t)ans+=tot[a[i].u];
if(a[i].v^a[i].t)ans+=tot[a[i].v];
}
}
}
namespace Solve4{
int n;
int sum[N];
int f[N],g[N];
void calc(){
for(int i=1;i<=m;i++){
if(a[i].x>a[i].y)swap(a[i].x,a[i].y),swap(a[i].u,a[i].v);
}
sort(a+1,a+1+m,[](path x,path y){
return x.t^y.t?x.t<y.t:(x.x^y.x?x.x<y.x:x.y<y.y);
});
for(int i=1,j;i<=m;i=j){
for(j=i;j<=m&&a[j].t==a[i].t;j++);
n=0;
for(int k=i;k<j;k++){
if(a[k].u^a[k].t)sum[a[k].x]++;
if(a[k].v^a[k].t)sum[a[k].y]++;
}
ans+=(j-i)*(j-i-1ll)/2;
int u=a[i].t;
for(int v:to[u])if(v^fa[u])ans-=sum[v]*(sum[v]-1ll)/2;
}
for(int i=1,j;i<=m;i=j){
for(j=i;j<=m;j++){
if(a[j].t^a[i].t||a[j].x^a[i].x||a[j].y^a[i].y)break;
}
if(a[i].x&&a[i].y)ans+=(j-i)*(j-i-1ll)/2;
}
for(int i=1;i<=m;i++){
f[a[i].t]++;
if(a[i].u^a[i].t)g[a[i].x]--;
if(a[i].v^a[i].t)g[a[i].y]--;
}
for(int u=2;u<=::n;u++){
ans+=1ll*cnt[u]*(f[fa[u]]+g[u]);
}
}
}
int main(){
freopen(".in","r",stdin);
//freopen(".out","w",stdout);
scanf("%d",&n);
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
to[u].push_back(v),to[v].push_back(u);
}
T::init();
int mm;
scanf("%d",&mm);
for(int u,v;mm--;){
scanf("%d%d",&u,&v);
if(u==v)cnt[u]++;
else a[++m]={u,v,LCA(u,v)};
}
Solve1::calc();
Solve2::calc();
Solve3::calc();
Solve4::calc();
cout<<ans;
return 0;
}