一个需要感性理解的树上算法 学习心得

题目描述

你现在有一颗 \(n\) 个点的树和 \(m\) 条由 \(x_i\)\(y_i\) ( \(1 \le x_i\ ,\ y_i \le n\) ) 的简单可重复路径。求有多少种方案选路径,使路径集的大小为 \(k\) ,且所有路径至少有一个公共点。对 \(10^9+7\) 取模。

题解

这道题需要感性理解做法。我们记一个路径左右端点的最近公共祖先为这个路径的 “最高点” ( lca ),那么对于任意一个点而言,若有 \(x\) 条路径经过,其中 \(y\) 条的最高点不是这个点,那么这个点就能贡献 \(C_{x}^{x+y}-C_x^{y}\) 种方案。
感性理解,就是两个路径的交点肯定是其中一条的 lca ,而有 \(C_{x}^y\) 种方案会在更高的点上被统计,所以要减去。

原题链接 - Super

/*  
  bug:1.初始化fac的范围不是1~n而是1~N 除非在输入n之后初始化  
      2.树剖的rev 不是 rev[++dfc]=x 而是 rev[x] = ++dfc.  
 */  
#include<bits/stdc++.h>  
using namespace std;  
#define N 2000005  
#define ll long long  
#define P 1000000007ll  
int thr[N],hig[N],cf[N];  
int n,m,k;  
int tot,first[N],nxt[N*2],to[N*2];  
int top[N],rev[N],dep[N],siz[N],fah[N],son[N],dfc;  
ll fac[N];  
ll ans;  
void add(int x,int y){  
    nxt[++tot]=first[x];  
    first[x]=tot;  
    to[tot]=y;  
    return;  
}  
ll qpow(ll a,ll b){  
    ll c=1;  
    while(b>0){  
        if(b&1)c=c*a%P;  
        b>>=1;  
        a=a*a%P;  
    }  
    return c;  
}  
ll inv(ll x){  
    return qpow(x,P-2);  
}  
ll C(ll n,ll m){  
    if(n<m)return 0;  
    return fac[n]*inv(fac[m])%P*inv(fac[n-m])%P;  
}  
template<typename A,typename B>  
void Add(A &x,B y){  
    x+=y;  
    if(x>=P)x-=P;  
    return;  
}  
void spread(int x){  
    siz[x]=1;  
    for(int e=first[x];e;e=nxt[e]){  
        int u=to[e];  
        if(siz[u])continue;  
        dep[u]=dep[x]+1;  
        fah[u]=x;  
        spread(u);  
        siz[x]+=siz[u];  
        if(siz[u]>siz[son[x]])son[x]=u;  
    }  
    return;  
}  
void chain(int x,int curtop){  
    rev[x]=++dfc;  
    top[x]=curtop;  
    if(!son[x])return;  
    chain(son[x],curtop);  
    for(int e=first[x];e;e=nxt[e]){  
        int u=to[e];  
        if(top[u])continue;  
        chain(u,u);  
    }  
    return;  
}  
int ancestor(int x,int y){  
    while(top[x]!=top[y]){  
        if(dep[top[x]]<dep[top[y]])y=fah[top[y]];  
        else x=fah[top[x]];  
    }  
    if(rev[x]>rev[y])return y;  
    return x;  
}  
void getans(int x,int fa){  
    thr[x]=cf[x];  
    for(int e=first[x];e;e=nxt[e]){  
        int u=to[e];  
        if(u==fa)continue;  
        getans(u,x);  
        Add(thr[x],thr[u]);  
    }  
    if(hig[x]==0)return;  
    Add(ans,(C(thr[x],k)-C(thr[x]-hig[x],k)+P)%P);  
    return;  
}  
void read(int &x){  
    char ch=getchar();  
    x=0;  
    while(!isdigit(ch))ch=getchar();  
    while(isdigit(ch)){  
        x=x*10+ch-'0';  
        ch=getchar();  
    }  
    return;  
}  
template<typename T,typename...Types>  
void read(T &x,Types &...args){  
    read(x);  
    read(args...);  
    return;  
}  
int main(){  
    freopen("desire.in","r",stdin);  
    freopen("desire.out","w",stdout);  
    int x,y;  
    fac[0]=dep[1]=1;  
    for(int i=1;i<N;++i){  
        fac[i]=fac[i-1]*i%P;  
    }  
    read(n,m,k);  
    for(int i=1;i<n;++i){  
        read(x,y);  
        add(x,y);  
        add(y,x);  
    }  
    spread(1);  
    chain(1,1);  
    for(int i=1;i<=m;++i){  
        read(x,y);  
        int lca=ancestor(x,y);  
        ++cf[x];  
        ++cf[y];  
        --cf[lca];  
        --cf[fah[lca]];  
        ++hig[lca];  
    }  
    getans(1,0);  
    printf("%lld",ans);  
    return 0;  
}
posted @ 2023-10-13 14:36  DZhearMins  阅读(9)  评论(0编辑  收藏  举报