NOIP 模拟 $52\; \rm 路径$
题解 \(by\;zj\varphi\)
本质上只需对每个长度 \(x\) 求出长度(边数)为 \(x\) 的路径条数 \(ans_x\),最终答案即为 \(\sum_{x=1}^{n-1} ans_x\cdot x^k\)。
考虑点分治:对每个分治中心,依次把它的各棵子树合并进来。设 \(ans_x\) 表示长度为 \(x\) 的路径条数,则每合并一棵子树,
\[ans_x \mathrel{+}= \sum_{i=0}^{x} depa_i \cdot depb_{x-i}\]
其中 \(depa_i\) 表示已合并部分(含分治中心本身,即 \(depa_0=1\))中到分治中心距离为 \(i\) 的点数,\(depb_j\) 表示当前子树中到分治中心距离为 \(j\) 的点数。
发现这是卷积形式,且模数为 \(998244353\),可以直接用 \(\rm NTT\) 计算;但朴素实现会被卡常,用 unsigned long long 存储中间结果、延迟取模来减小 \(\rm NTT\) 的常数即可。
复杂度 \(\mathcal O(n\log^2 n)\)。注意需要把分治中心的各棵子树按深度从小到大合并,这样每次卷积的长度都不超过当前子树深度的常数倍;否则若先合并一棵很深的子树,后面每棵浅子树都要做一次很长的卷积,复杂度会退化。
Code
#include<bits/stdc++.h>
// Storage-class shorthand used by older code. Intentionally empty: the
// `register` storage-class specifier was removed in C++17.
#define Re
// `ri` declares a plain signed int (typically a loop variable).
#define ri Re signed
// Pre-increment / pre-decrement shorthands.
#define pd(i) ++(i)
#define bq(i) --(i)
namespace IO{
// Bulk input buffer: gc() hands out bytes from [p1,p2) and refills it from stdin when empty.
char buf[1<<21],*p1=buf,*p2=buf;
// Next input byte as a non-negative int, or -1 once stdin is exhausted.
// The whole expression is parenthesised so the macro composes safely with other
// operators, and bytes are widened through unsigned char so they are valid
// arguments for std::isdigit.
#define gc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?-1:static_cast<unsigned char>(*p1++))
// Minimal integer reader: `cin >> x` skips to the next run of digits and parses it into x.
struct nanfeng_stream{
    // Parse one integer into x. Any '-' seen while skipping non-digit bytes makes
    // the value negative. If input ends before a digit is found, the skip loop
    // stops (instead of spinning forever on -1) and x is left as 0.
    template<typename T>inline nanfeng_stream &operator>>(T &x) {
        bool f=false;x=0;int ch=gc();
        while(ch!=-1&&!std::isdigit(ch)) f|=ch=='-',ch=gc();
        while(std::isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=gc();
        return x=f?-x:x,*this;
    }
}cin;
}
using IO::cin;
namespace nanfeng{
#define FI FILE *IN
#define FO FILE *OUT
template<typename T>inline T cmax(T x,T y) {return x>y?x:y;}
template<typename T>inline T cmin(T x,T y) {return x>y?y:x;}
using ll=long long;
using ull=unsigned long long;
static const int N=1<<21,MOD=998244353;
struct edge{int v,nxt;}e[N<<1];
int first[N],siz[N],dep[N],R[N],n,k,t=1,cnt,pos,cmp,mx;
bool G[N];
ll w1[N],w2[N],ans[N],as;
ull a[N],b[N],c[N],d[N];
auto add=[](int u,int v) {
e[t].v=v,e[t].nxt=first[u],first[u]=t++;
e[t].v=u,e[t].nxt=first[v],first[v]=t++;
};
auto MD=[](ll x) {return x>=MOD?x-MOD:x;};
auto fpow=[](ll x,int y) {
ll res=1;
while(y) {
if (y&1) res=res*x%MOD;
x=x*x%MOD;
y>>=1;
}
return res;
};
void dfs_find(int x,int fa) {
siz[x]=1;
int GS(0);
for (ri i(first[x]),v;i;i=e[i].nxt) {
if ((v=e[i].v)==fa||G[v]) continue;
dfs_find(v,x);
GS=cmax(siz[v],GS);
siz[x]+=siz[v];
}
GS=cmax(GS,cmp-siz[x]);
if (GS<cnt) cnt=GS,pos=x;
}
void dfs_solve(int x,int fa) {
siz[x]=1;
for (ri i(first[x]),v;i;i=e[i].nxt) {
if ((v=e[i].v)==fa||G[v]) continue;
dep[v]=dep[x]+1;
++b[dep[v]];
mx=cmax(mx,dep[v]);
dfs_solve(v,x);
siz[x]+=siz[v];
}
}
void solve(int x,int S) {
cmp=cnt=S;
dfs_find(x,0);
int np;
G[np=pos]=true;
int mxp=0;
a[0]=1;
for (ri i(first[np]),v;i;i=e[i].nxt) {
if (G[v=e[i].v]) continue;
mx=dep[v]=1;
b[1]=1;
dfs_solve(v,np);
auto calc=[mxp]() {
int st=1,len=0;
while(st<=mxp+mx+2) st<<=1,++len;
int inv=fpow(st,MOD-2);
w1[1]=fpow(3,(MOD-1)/st),w2[1]=fpow(w1[1],MOD-2);
for (ri i(2);i<st;pd(i)) w1[i]=w1[i-1]*w1[1]%MOD,w2[i]=w2[i-1]*w2[1]%MOD;
for (ri i(0);i<st;pd(i)) R[i]=(R[i>>1]>>1)|((i&1)<<(len-1));
auto NTT1=[st](ull *a) {
for (ri i(0);i<st;pd(i)) if (R[i]>i) std::swap(a[R[i]],a[i]);
for (ri t(st>>1),d(1);d<st;t>>=1,d<<=1)
for (ri i(0);i<st;i+=d<<1)
for (ri j(0);j<d;pd(j)) {
const ll tmp=w1[t*j]*a[i+j+d]%MOD;
a[i+j+d]=a[i+j]-tmp+MOD;
a[i+j]=a[i+j]+tmp;
}
for (ri i(0);i<st;pd(i)) a[i]%=MOD;
};
auto NTT2=[st,inv](ull *a) {
for (ri i(0);i<st;pd(i)) if (R[i]>i) std::swap(a[R[i]],a[i]);
for (ri t(st>>1),d(1);d<st;t>>=1,d<<=1)
for (ri i(0);i<st;i+=d<<1)
for (ri j(0);j<d;pd(j)) {
const ll tmp=w2[t*j]*a[i+j+d]%MOD;
a[i+j+d]=a[i+j]-tmp+MOD;
a[i+j]=a[i+j]+tmp;
}
for (ri i(0);i<st;pd(i)) a[i]=a[i]%MOD*inv%MOD;
};
memcpy(c,a,sizeof(ull)*(mxp+1));
memset(c+mxp+1,0,sizeof(ull)*(st-mxp));
memcpy(d,b,sizeof(ull)*(mx+1));
memset(d+mx+1,0,sizeof(ull)*(st-mx));
NTT1(c),NTT1(d);
for (ri i(0);i<st;pd(i)) c[i]=c[i]*d[i]%MOD;
NTT2(c);
for (ri i(1);i<st;pd(i)) ans[i]=ans[i]+c[i];
};
calc();
mxp=cmax(mxp,mx);
for (ri j(1);j<=mx;pd(j)) a[j]+=b[j],b[j]=0;
}
memset(a,0,sizeof(ll)*(mxp+1));
for (ri i(first[np]),v;i;i=e[i].nxt) {
if (G[v=e[i].v]) continue;
solve(v,siz[v]);
}
}
inline int main() {
// FI=freopen("nanfeng.in","r",stdin);
// FO=freopen("nanfeng.out","w",stdout);
cin >> n >> k;
for (ri i(1),u,v;i<n;pd(i)) cin >> u >> v,add(u,v);
w1[0]=w2[0]=1;
solve(1,n);
for (ri i(1);i<n;pd(i)) as=MD(as+ans[i]%MOD*fpow(i,k)%MOD);
printf("%lld\n",as);
return 0;
}
}
// Process entry point: all of the work is done by the solver's driver.
int main() {
    const int status = nanfeng::main();
    return status;
}