BSOJ6387题解
算是刷新了我对树上问题的认知
首先第一问随便做一个 \(O(nk)\) 的 DP 就可以草过去,考虑第二问。
我们将问题分为两个部分:走儿子边的答案和走父亲边的答案。最后拼接一下就好了。
设 \(fd[u][k]\) 是走儿子边且距离不超过 \(k\) 的节点数量,\(fu[u][k]\) 是走父亲边的答案;\(gd[u][k]\) 是走儿子边的拥挤程度,\(gu[u][k]\) 同理。
这几个转移起来相当简单,不再赘述。可以做到 \(O(nk\log n)\) 或 \(O(n(k+\log n))\)。
#include<cstdio>
typedef unsigned ui;
const ui M=1e5+5,K=15,mod=1e9+7;
ui n,k,cnt,h[M],f[M],ans[M],fd[M][K],fu[M][K],gd[M][K],gu[M][K];
struct Edge{
ui v,nx;
}e[M<<1];
inline void Add(const ui&u,const ui&v){
e[++cnt]=(Edge){v,h[u]};h[u]=cnt;
e[++cnt]=(Edge){u,h[v]};h[v]=cnt;
}
inline void swap(ui&a,ui&b){
ui c=a;a=b;b=c;
}
inline ui pow(ui a,ui b){
ui ans(1);for(;b;b>>=1,a=1ull*a*a%mod)if(b&1)ans=1ull*ans*a%mod;return ans;
}
inline void init(const ui&u){
for(ui v,E=h[u];E;E=e[E].nx)if((v=e[E].v)^f[u])f[v]=u,init(v);
}
inline void DFS1(const ui&u){
for(ui i=0;i<=k;++i)fd[u][i]=gd[u][i]=1;
for(ui v,E=h[u];E;E=e[E].nx)if((v=e[E].v)^f[u]){
DFS1(v);
for(ui i=1;i<=k;++i){
fd[u][i]+=fd[v][i-1];gd[u][i]=1ull*gd[u][i]*gd[v][i-1]%mod;
}
}
for(ui i=0;i<=k;++i)gd[u][i]=1ull*gd[u][i]*fd[u][i]%mod;
}
inline void DFS2(const ui&u){
static ui t[K],inv[K];inv[1]=1;
for(ui i=0;i<=k;++i)fu[u][i]=gu[u][i]=1;
if(u!=1){
++fu[u][1];
gu[u][1]=1ull*gu[f[u]][0]*gd[f[u]][0]%mod*fu[u][1]%mod;
for(ui i=2;i<=k;++i){
const ui&sz1=fu[f[u]][i-1],&sz2=fd[f[u]][i-1],&sz3=fd[u][i-2];
fu[u][i]=sz1+sz2-sz3;
gu[u][i]=1ull*gu[f[u]][i-1]*gd[f[u]][i-1]%mod*(fu[u][i]-1)%mod*fu[u][i]%mod;
t[i]=1ull*gd[u][i-2]*sz1%mod*sz2%mod;
inv[i]=1ull*inv[i-1]*t[i]%mod;
}
inv[k]=pow(inv[k],mod-2);
for(ui i=k;i>1;--i)swap(inv[i],inv[i-1]),inv[i]=1ull*inv[i]*inv[i-1]%mod,inv[i-1]=1ull*inv[i-1]*t[i]%mod;
for(ui i=2;i<=k;++i)gu[u][i]=1ull*gu[u][i]*inv[i]%mod;
}
for(ui v,E=h[u];E;E=e[E].nx)if((v=e[E].v)^f[u])DFS2(v);
}
signed main(){
scanf("%u%u",&n,&k);
for(ui i=1;i<n;++i){
ui u,v;scanf("%u%u",&u,&v);
Add(u,v);
}
init(1);DFS1(1);DFS2(1);
for(ui u=1;u<=n;++u){
const ui&sz1=fd[u][k],sz2=fu[u][k];
ans[u]=1ull*gd[u][k]*gu[u][k]%mod*pow(1ull*sz1*sz2%mod,mod-2)%mod*(sz1+sz2-1)%mod;
printf("%u ",sz1+sz2-1);
}
printf("\n");
for(ui u=1;u<=n;++u)printf("%u ",ans[u]);
}