「题解」洛谷 P5327 [ZJOI2019]语言
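简单做法:先树链剖分,每条路径被拆成 \(\mathcal{O}(\log n)\) 个 \(dfn\) 区间,这些区间两两组合后,问题就变成在 \(dfn\times dfn\) 平面上做 \(\mathcal{O}(\log^2 n)\) 个矩形覆盖,最后查询被覆盖(值非零)的位置个数;用扫描线加线段树处理,总复杂度是 \(\mathcal{O}(n\log^3n)\)(视 \(n,m\) 同阶).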
稍微卡卡常数甚至不需要怎么卡就过了。
#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<ctime>
// Shorthand macros for container/pair access.
// pb: append to a container via emplace_back.
#define pb emplace_back
// mp: build a std::pair.
#define mp std::make_pair
// fi / se: the two members of a std::pair.
#define fi first
#define se second
// Debug helpers: print the source line plus the name and value of one (dbg) or
// two (dpi) expressions to cerr. They name cerr unqualified, so they rely on
// the file's `using namespace std;`, and each expansion already ends with its
// own ';'. Neither macro is invoked in the code visible in this file.
#define dbg(x) cerr<<"In Line "<< __LINE__<<" the "<<#x<<" = "<<x<<'\n';
#define dpi(x,y) cerr<<"In Line "<<__LINE__<<" the "<<#x<<" = "<<x<<" ; "<<"the "<<#y<<" = "<<y<<'\n';
// The whole file uses unqualified standard-library names (sort, min, cout, cerr, ...).
using namespace std;
// Short aliases for integer, pair and vector types. In the code visible here
// only ll (answer accumulator), vi (adjacency lists) and vpii (chain segments
// of a path) are actually used.
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int>pii;
typedef pair<ll,int>pli;
typedef pair<ll,ll>pll;
typedef vector<int>vi;
typedef vector<ll>vll;
typedef vector<pii>vpii;
// "Check-max": stores max(x, y) into x (y wins ties) and returns the stored
// value by copy.
template<typename T>
T cmax(T &x, T y){
    if(!(x>y)){
        x=y;
    }
    return x;
}
// "Check-min": stores min(x, y) into x (y wins ties) and returns the stored
// value by copy.
template<typename T>
T cmin(T &x, T y){
    if(!(x<y)){
        x=y;
    }
    return x;
}
// Reads one decimal integer from stdin into r and returns a reference to r.
//
// Every character before the first digit is skipped; the value is negated
// when the character immediately preceding the first digit is '-'. The
// character that ends the digit run is consumed too. If the input ends before
// any digit is found, r is left at 0 and the function returns instead of
// looping forever on EOF.
template<typename T>
T &read(T &r){
    r=0;
    bool neg=false;
    // Keep getchar()'s result in an int: narrowing it to char makes EOF
    // impossible to tell apart from data.
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9')){
        neg=(ch=='-');
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        r=r*10+(ch^48);   // ch^48 maps '0'..'9' to 0..9
        ch=getchar();
    }
    return r=neg?-r:r;
}
// Reads several integers in argument order:
// read(a,b,c) behaves like read(a); read(b); read(c);
template<typename T1,typename... T2>
void read(T1 &x, T2& ...y){ read(x); read(y...); }
// Capacity of every per-vertex array. Index 0 doubles as the "no vertex"
// value: dfs1 is started with parent 0 and son[x]==0 means "no heavy child".
const int N=100010;
// n: number of vertices, m: number of paths; both are read in main.
// c: difference array over dfn positions updated by PushLine; nothing in the
//    code visible here reads it back.
int n,m,c[N];
// Heavy-light decomposition state, filled by dfs1 and dfs2:
//   siz - subtree size                 dep - depth (the root gets depth 1)
//   fa  - parent (0 for the root)      son - heavy child (0 if none)
//   top - topmost vertex of the heavy chain containing the vertex
//   dfn - visiting order assigned by dfs2 (heavy child first), values 1..n
//   dft - last dfn value handed out
int siz[N],dfn[N],top[N],dep[N],fa[N],dft,son[N];
// Adjacency lists of the tree (each edge is stored in both directions by main).
vi eg[N];
// First decomposition pass, rooted traversal from x whose parent is f.
// Records parent, depth and subtree size of every vertex below x and picks
// as heavy child the child with the strictly largest subtree (first one wins
// on ties).
void dfs1(int x,int f){
    fa[x]=f;
    dep[x]=dep[f]+1;
    siz[x]=1;
    for(const int child:eg[x]){
        if(child==f)continue;          // do not walk back up the tree
        dfs1(child,x);
        siz[x]+=siz[child];
        if(siz[son[x]]<siz[child]){
            son[x]=child;
        }
    }
}
// Second decomposition pass: numbers the vertices and assigns chain heads.
// x receives the next dfn value and chain head t. The heavy child is visited
// first so every heavy chain occupies consecutive dfn values; each light
// child starts a new chain headed by itself.
void dfs2(int x,int t){
    dfn[x]=++dft;
    top[x]=t;
    const int heavy=son[x];
    if(heavy!=0){
        dfs2(heavy,t);
    }
    for(const int child:eg[x]){
        if(child==fa[x]||child==heavy)continue;
        dfs2(child,child);
    }
}
//Tree
// Number of sweep events stored so far; they live in li[1..lct].
int lct;
// One sweep-line event: when the sweep reaches row h, add v to every column
// in [l,r]. Events compare by row only, so sorting groups them per row.
// Kept as a plain aggregate because PushLine-style brace initialisation
// {l,r,h,v} depends on this field order.
struct Line{
    int l;
    int r;
    int h;
    int v;
    bool operator<(const Line &y)const{
        return h<y.h;
    }
};
// Fixed-capacity event buffer (N*300 = 30,003,000 entries of four ints each).
Line li[N*300];
inline void PushLine(int l,int r,int h,int v){
c[l]++;c[r+1]--;
li[++lct]={l,r,h,v};
}
// Segment tree storage; nodes are numbered from 1 in the order build() creates them.
//   trct  - number of nodes allocated so far
//   mn    - minimum value inside the node's range (tags still pending on
//           ancestors are not included)
//   mct   - how many positions of the range attain that minimum
//   tag   - addition applied to this node but not yet pushed to its children
//   ls/rs - left / right child index (stays 0 for leaves)
int trct,mn[N<<1],mct[N<<1],tag[N<<1],ls[N<<1],rs[N<<1];
// Allocates the subtree covering positions [l,r] with every value equal to 0
// and returns the index of its root node. Nodes are numbered in preorder.
int build(int l,int r){
    const int node=++trct;
    mn[node]=0;
    mct[node]=r-l+1;      // all positions share the minimum 0
    if(l!=r){
        const int mid=(l+r)>>1;
        ls[node]=build(l,mid);
        rs[node]=build(mid+1,r);
    }
    return node;
}
// Recomputes node x from its two children: the minimum of both, and the
// number of positions that attain it (summing the children that share it).
inline void pushup(int x){
    const int left=ls[x];
    const int right=rs[x];
    const int best=std::min(mn[left],mn[right]);
    int holders=0;
    if(mn[left]==best)holders+=mct[left];
    if(mn[right]==best)holders+=mct[right];
    mn[x]=best;
    mct[x]=holders;
}
// Adds v to the whole range of node x: shifts its minimum and remembers the
// addition in its tag for the children.
inline void upd(int x,int v){
    tag[x]+=v;
    mn[x]+=v;
}
// Hands the pending addition of node x to both children and clears it.
// Does nothing when there is no pending addition.
inline void pushdown(int x){
    const int pending=tag[x];
    if(pending==0)return;
    upd(ls[x],pending);
    upd(rs[x],pending);
    tag[x]=0;
}
// Range addition: adds v to every position in [l,r].
// x is the current node and [tl,tr] the range it covers.
void modify(int x,int tl,int tr,int l,int r,int v){
    if(l<=tl&&tr<=r){
        // node lies fully inside the target range: tag it and stop
        upd(x,v);
        return;
    }
    pushdown(x);
    const int mid=(tl+tr)>>1;
    if(l<=mid)modify(ls[x],tl,mid,l,r,v);
    if(r>mid)modify(rs[x],mid+1,tr,l,r,v);
    pushup(x);
}
int query(int x,int tl,int tr,int l,int r){
if(tl>=l&&tr<=r)return !mn[x]?mct[x]:0;
int mid=(tl+tr)>>1,s=0;
pushdown(x);
if(mid>=l)s+=query(ls[x],tl,mid,l,r);
if(mid<r)s+=query(rs[x],mid+1,tr,l,r);
pushup(x);
return s;
}
//
// Registers the tree path x..y with the sweep.
// The path is cut into dfn intervals, one per heavy chain it touches. For
// every ordered pair (column interval, row interval) of those pieces a
// rectangle is emitted as two events: +1 on the columns at the first row and
// -1 just past the last row (omitted when the last row is n). Pairs whose
// column interval starts at or after the end of the row interval are skipped.
void qwdasd(int x,int y){
    std::vector<std::pair<int,int> > segs;
    while(top[x]!=top[y]){
        // climb from whichever endpoint sits on the chain with the deeper head
        int &low=dep[top[x]]>dep[top[y]]?x:y;
        segs.emplace_back(dfn[top[low]],dfn[low]);
        low=fa[top[low]];
    }
    // both endpoints now share a chain: add the piece between them
    if(dep[x]<=dep[y]){
        segs.emplace_back(dfn[x],dfn[y]);
    }else{
        segs.emplace_back(dfn[y],dfn[x]);
    }
    for(const auto &cols:segs){
        for(const auto &rows:segs){
            if(cols.first>=rows.second)continue;
            PushLine(cols.first,cols.second,rows.first,1);
            if(rows.second!=n){
                PushLine(cols.first,cols.second,rows.second+1,-1);
            }
        }
    }
}
signed main(){
read(n,m);
for(int i=1,u,v;i<n;i++){
read(u,v);
eg[u].pb(v);eg[v].pb(u);
}
dfs1(1,0);
dfs2(1,1);
for(int i=1;i<=m;i++){
int x,y;read(x,y);
qwdasd(x,y);
}
sort(li+1,li+lct+1);
build(1,n);
int p=1;
ll ans=0;
for(int i=1;i<=n;i++){
while(p<=lct&&li[p].h==i){
modify(1,1,n,li[p].l,li[p].r,li[p].v);
++p;
}
if(!mn[1]&&i>1)ans+=query(1,1,n,1,i-1);
}
cout << 1ll*n*(n-1)/2-ans << '\n';
#ifdef do_while_true
cerr<<'\n'<<"Time:"<<clock()<<" ms"<<'\n';
#endif
return 0;
}