CF1010F Tree
CF1010F Tree
重链剖分+\(NTT\)
直接考虑\(A_i\)并不容易,我们可以进行差分,具体来说,令:
\[B_i=A_i-\sum_{j \in son_i} A_j
\]
这样我们只需要保证\(B_i>0\)即可,同时\(\sum B_i = x\)。
根据插板法,如果有\(k\)个点,那么它的贡献就是\({x+k-1 \choose k-1}\)。
接下来的任务就是对于每个\(k\),计算出方案数。
对于每个节点,建立答案的生成函数。
如果有两个子节点:
\[F_u(x)=x F_{v1} (x) F_{v2} (x)+1
\]
一个子节点:
\[F_u(x)=xF_v(x)+1
\]
无子节点:
\[F_u(x)=x+1
\]
利用重链剖分进行优化,进行链分治,首先计算出所有轻儿子的生成函数。
设\(F_0=x\),从下至上的轻儿子生成函数乘上\(x\)为\(F_1,F_2,\cdots,F_m\)(无轻儿子则\(F_i=x\))。
那么我们可以计算出重链顶端的生成函数:
\[F=(((F_0+1)F_1+1)F_2+1)\cdots +1\\
=F_0F_1F_2\cdots+F_1F_2\cdots+\cdots+1
\]
令\(S=F_0F_1F_2\cdots+F_1F_2\cdots+\cdots+1,T=F_0F_1F_2\cdots\)。
进行分治计算,计算出左右两部分的答案\(S_0,T_0,S_1,T_1\)。
则:
\[S=(S_0-1)T_1+S_1\\
T=T_0T_1
\]
考虑一下复杂度上限,对于一颗大小为\(t\)的子树,我们会进行分治,分治合并时需要利用卷积,如果两个长度为\(n,m\)的多项式卷积复杂度近似看做\(O((n+m)\log x)\)(带了\(\log\)之后\(x\)是多少并不重要),那么我们单独考虑贡献,把分治当成一颗二叉树,对于每个叶子节点,若其大小为\(c\),那么它一共在\(\log t\)个节点有贡献,每次贡献看成\(c \log x\),总共的贡献就是\(c \log^2 x\),所以一颗大小为\(t\)的子树贡献就是\(O(t \log^2 t)\)。
根据重链剖分轻子树大小总和为\(n \log n\),得出时间复杂度为\(O(n \log^3 n)\)。
\(Code:\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#define N 100005
#define ll long long
#define IT vector<int> :: iterator
using namespace std;
const int p=998244353;
int s,l,G[2][25],rev[N << 1];
void Add(int &x,int y)
{
x=(x+y)%p;
}
void Del(int &x,int y)
{
x=(x-y)%p;
}
void Mul(int &x,int y)
{
x=(ll)x*y%p;
}
int add(int x,int y)
{
return (x+y)%p;
}
int del(int x,int y)
{
return (x-y)%p;
}
int mul(int x,int y)
{
return (ll)x*y%p;
}
int ksm(int x,int y)
{
int ans=1;
while (y)
{
if (y & 1)
Mul(ans,x);
Mul(x,x);
y >>=1;
}
return ans;
}
void Pre()
{
G[0][23]=ksm(3,(p-1)/(1 << 23));
G[1][23]=ksm(G[0][23],p-2);
for (int i=22;i;--i)
{
G[0][i]=mul(G[0][i+1],G[0][i+1]);
G[1][i]=mul(G[1][i+1],G[1][i+1]);
}
}
void solve(int n)
{
s=1,l=0;
while (s<n)
s <<=1,++l;
for (int i=0;i<s;++i)
rev[i]=(rev[i >> 1] >> 1) | ((i & 1) << l-1);
}
struct Poly
{
int n;
vector<int>a;
int& operator [] (int x)
{
return a[x];
}
void read(int zn)
{
n=zn,a.clear();
for (int i=0;i<n;++i)
a.push_back(0),scanf("%d",&a[i]);
}
void print()
{
puts("--------------------");
printf("Len: %d\n",n);
for (int i=0;i<n;++i)
printf("%d ",a[i]);
putchar('\n');
puts("--------------------");
}
void clean()
{
n=0,a.clear();
}
void reuse(int zn)
{
n=zn,a.clear();
for (int i=0;i<n;++i)
a.push_back(0);
}
void extend(int S=s)
{
int t=S-a.size();
for (int i=1;i<=t;++i)
a.push_back(0);
}
void rollback(int S)
{
int t=a.size()-S;
for (int i=0;i<t;++i)
a.pop_back();
}
void NTT(int t)
{
for (int i=0;i<s;++i)
if (i<rev[i])
swap(a[i],a[rev[i]]);
for (int mid=1,o=1;mid<s;mid <<=1,++o)
for (int j=0;j<s;j+=mid << 1)
{
int g=1;
for (int k=0;k<mid;++k,Mul(g,G[t][o]))
{
int x=a[j+k],y=mul(g,a[j+k+mid]);
a[j+k]=add(x,y);
a[j+k+mid]=del(x,y);
}
}
}
void minv(int S=s)
{
int t=ksm(S,p-2);
for (int i=0;i<s;++i)
Mul(a[i],t);
}
};
Poly operator + (Poly f,Poly g)
{
int n=max(f.n,g.n);
f.n=n;
f.extend(n),g.extend(n);
for (int i=0;i<n;++i)
Add(f[i],g[i]);
return f;
}
void operator += (Poly &f,Poly &g)
{
int n=max(f.n,g.n);
f.n=n;
f.extend(n),g.extend(n);
for (int i=0;i<n;++i)
Add(f[i],g[i]);
g.rollback(g.n);
}
Poly operator * (Poly f,Poly g)
{
int n=f.n,m=g.n;
solve(n+m);
f.extend(),g.extend();
f.NTT(0),g.NTT(0);
for (int i=0;i<s;++i)
Mul(f[i],g[i]);
f.NTT(1),f.minv();
f.n=n+m-1,f.rollback(f.n);
return f;
}
void operator *= (Poly &f,Poly &g)
{
int n=f.n,m=g.n;
solve(n+m);
f.extend(),g.extend();
f.NTT(0),g.NTT(0);
for (int i=0;i<s;++i)
Mul(f[i],g[i]);
f.NTT(1),f.minv();
f.n=n+m-1,f.rollback(f.n),g.rollback(m);
}
void polyswap(Poly &f,Poly &g)
{
f.a.swap(g.a),swap(f.n,g.n);
}
int n,x,y,ans;
ll X;
struct edge
{
int nxt,v;
edge () {}
edge (int Nxt,int V):nxt(Nxt),v(V) {}
}e[N << 1];
int tot,fr[N],sz[N],son[N],fa[N];
vector<int>H[N];
void link(int x,int y)
{
++tot;
e[tot]=edge(fr[x],y),fr[x]=tot;
}
void dfs(int u)
{
sz[u]=1;
int mx=-1;
for (int i=fr[u];i;i=e[i].nxt)
{
int v=e[i].v;
if (v==fa[u])
continue;
fa[v]=u;
dfs(v);
sz[u]+=sz[v];
if (sz[v]>mx)
mx=sz[v],son[u]=v;
}
}
void dfs2(int u,int tp)
{
H[tp].push_back(u);
if (!son[u])
return;
dfs2(son[u],tp);
for (int i=fr[u];i;i=e[i].nxt)
{
int v=e[i].v;
if (v==fa[u] || v==son[u])
continue;
dfs2(v,v);
}
}
#define ls (p << 1)
#define rs (p << 1 | 1)
Poly Z,Z2,F[N],S[N << 2],T[N << 2];
void modify(int p,int l,int r,int x,Poly &a)
{
if (l==r)
{
polyswap(S[p],a);
a.clean();
T[p]=S[p];
++S[p][0];
return;
}
int mid=(l+r) >> 1;
if (x<=mid)
modify(ls,l,mid,x,a); else
modify(rs,mid+1,r,x,a);
}
void calc(int p,int l,int r)
{
if (l==r)
return;
int mid=(l+r) >> 1;
calc(ls,l,mid);
calc(rs,mid+1,r);
--S[ls][0];
S[p]=S[ls]*T[rs]+S[rs];
T[p]=T[ls]*T[rs];
S[ls].clean(),T[ls].clean();
S[rs].clean(),T[rs].clean();
}
void Solve(int u)
{
if (!son[u])
{
F[u].reuse(2);
F[u][0]=F[u][1]=1;
return;
}
int cnt=0;
for (IT it=H[u].begin();it!=H[u].end();++it)
{
int v=*it;
++cnt;
for (int i=fr[v];i;i=e[i].nxt)
{
int v2=e[i].v;
if (v2==fa[v] || v2==son[v])
continue;
Solve(v2);
}
}
--cnt;
Z2=Z;
modify(1,0,cnt,0,Z2);
reverse(H[u].begin(),H[u].end());
int rct=0;
for (IT it=H[u].begin()+1;it!=H[u].end();++it)
{
++rct;
int v=*it;
bool flag=false;
for (int i=fr[v];i;i=e[i].nxt)
{
int v2=e[i].v;
if (v2==fa[v] || v2==son[v])
continue;
flag=true;
reverse(F[v2].a.begin(),F[v2].a.end());
F[v2].a.push_back(0),++F[v2].n;
reverse(F[v2].a.begin(),F[v2].a.end());
modify(1,0,cnt,rct,F[v2]);
}
if (!flag)
Z2=Z,modify(1,0,cnt,rct,Z2);
}
calc(1,0,cnt);
polyswap(F[u],S[1]);
S[1].clean(),T[1].clean();
}
int main()
{
Pre();
scanf("%d%lld",&n,&X);
int zx=X%p;
for (int i=1;i<n;++i)
{
scanf("%d%d",&x,&y);
link(x,y),link(y,x);
}
dfs(1);
dfs2(1,1);
Z.reuse(2),Z[1]=1;
Solve(1);
int z1=1,z2=1;
for (int i=1;i<=n;++i)
{
Add(ans,mul(mul(z1,z2),F[1][i]));
Mul(z1,add(zx,i));
Mul(z2,ksm(i,p-2));
}
ans=(ans%p+p)%p;
printf("%d\n",ans);
return 0;
}