一个需要感性理解的树上算法 学习心得
题目描述
你现在有一颗 \(n\) 个点的树和 \(m\) 条由 \(x_i\) 到 \(y_i\) ( \(1 \le x_i\ ,\ y_i \le n\) ) 的简单可重复路径。求有多少种方案选路径,使路径集的大小为 \(k\) ,且所有路径至少有一个公共点。对 \(10^9+7\) 取模。
题解
这道题需要感性理解做法。我们记一个路径左右端点的最近公共祖先为这个路径的 “最高点” ( lca ),那么对于任意一个点而言,若有 \(x\) 条路径经过,其中 \(y\) 条的最高点不是这个点,那么这个点就能贡献 \(C_{x}^{x+y}-C_x^{y}\) 种方案。
感性理解,就是两个路径的交点肯定是其中一条的 lca ,而有 \(C_{x}^y\) 种方案会在更高的点上被统计,所以要减去。
/*
bug:1.初始化fac的范围不是1~n而是1~N 除非在输入n之后初始化
2.树剖的rev 不是 rev[++dfc]=x 而是 rev[x] = ++dfc.
*/
#include<bits/stdc++.h>
using namespace std;
#define N 2000005
#define ll long long
#define P 1000000007ll
int thr[N],hig[N],cf[N];
int n,m,k;
int tot,first[N],nxt[N*2],to[N*2];
int top[N],rev[N],dep[N],siz[N],fah[N],son[N],dfc;
ll fac[N];
ll ans;
void add(int x,int y){
nxt[++tot]=first[x];
first[x]=tot;
to[tot]=y;
return;
}
ll qpow(ll a,ll b){
ll c=1;
while(b>0){
if(b&1)c=c*a%P;
b>>=1;
a=a*a%P;
}
return c;
}
ll inv(ll x){
return qpow(x,P-2);
}
ll C(ll n,ll m){
if(n<m)return 0;
return fac[n]*inv(fac[m])%P*inv(fac[n-m])%P;
}
template<typename A,typename B>
void Add(A &x,B y){
x+=y;
if(x>=P)x-=P;
return;
}
void spread(int x){
siz[x]=1;
for(int e=first[x];e;e=nxt[e]){
int u=to[e];
if(siz[u])continue;
dep[u]=dep[x]+1;
fah[u]=x;
spread(u);
siz[x]+=siz[u];
if(siz[u]>siz[son[x]])son[x]=u;
}
return;
}
void chain(int x,int curtop){
rev[x]=++dfc;
top[x]=curtop;
if(!son[x])return;
chain(son[x],curtop);
for(int e=first[x];e;e=nxt[e]){
int u=to[e];
if(top[u])continue;
chain(u,u);
}
return;
}
int ancestor(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])y=fah[top[y]];
else x=fah[top[x]];
}
if(rev[x]>rev[y])return y;
return x;
}
void getans(int x,int fa){
thr[x]=cf[x];
for(int e=first[x];e;e=nxt[e]){
int u=to[e];
if(u==fa)continue;
getans(u,x);
Add(thr[x],thr[u]);
}
if(hig[x]==0)return;
Add(ans,(C(thr[x],k)-C(thr[x]-hig[x],k)+P)%P);
return;
}
void read(int &x){
char ch=getchar();
x=0;
while(!isdigit(ch))ch=getchar();
while(isdigit(ch)){
x=x*10+ch-'0';
ch=getchar();
}
return;
}
template<typename T,typename...Types>
void read(T &x,Types &...args){
read(x);
read(args...);
return;
}
int main(){
freopen("desire.in","r",stdin);
freopen("desire.out","w",stdout);
int x,y;
fac[0]=dep[1]=1;
for(int i=1;i<N;++i){
fac[i]=fac[i-1]*i%P;
}
read(n,m,k);
for(int i=1;i<n;++i){
read(x,y);
add(x,y);
add(y,x);
}
spread(1);
chain(1,1);
for(int i=1;i<=m;++i){
read(x,y);
int lca=ancestor(x,y);
++cf[x];
++cf[y];
--cf[lca];
--cf[fah[lca]];
++hig[lca];
}
getans(1,0);
printf("%lld",ans);
return 0;
}