【XSY3972】树与图(树形dp,树剖,分治NTT)
题面
题解
不难发现本题可以转化成以下题目: 给定一个 \(n\) 个点的有根树 ,你可以在树上选择 \(k\) 个点,满足对于任意两个点都不 互为祖先关系,且从根到每个叶子的路径上都恰好有一个被选择的点。求对于所有 \(i\in[1,n]\),求所有恰好选择 \(i\) 个点的方案数。
这显然可以树形 dp,设 \(f_{i,j}\) 表示以 \(i\) 为根的子树中恰好选择了 \(j\) 个节点的方案数(先不考虑选择顺序),合并就是树上背包。
发现这个树上背包可以用多项式乘法实现,即设 \(G_i(x)\) 表示 \(f_{i,j}\) 的生成函数,那么有:
(加了一个 \(x\) 是因为可以选择选自己,此时儿子子树内不能选)
这个转移显然可以用分治 NTT 实现,但时间过不去。
但是你发现你根本不需要求出所有点的 \(G(x)\),只需要求根节点的 \(G(x)\)。
两个 Subtask 给了我们启发:我们考虑对原树树剖,然后一条重链上的单独处理。
具体来说,设 \(F_u(x)\) 表示 \(u\) 的轻儿子的 \(G\) 的乘积,即:(记 \(son_u\) 为 \(u\) 的重儿子)
显然这个 \(F_u(x)\) 也是可以通过分治 NTT 得到的。
然后有:
需要注意的是,如果 \(u\) 没有重儿子(即 \(u\) 是叶子节点),那么赋值 \(F_u(x)=0\);如果 \(u\) 有重儿子但没有轻儿子,那么赋值 \(F_u(x)=1\)。
然后对于一条重链来说,设这条重链上的点由浅至深分别为 \(p_1,p_2,\cdots,p_k\)(显然 \(p_1\) 为链头,\(p_k\) 为叶子节点),不断代入 \(G_{p_i}(x)=F_u(x)G_{p_{i+1}}(x)+x\),有:
(由第一步跳到第二步是因为 \(F_{p_k}(x)=0\))
这个也可以通过分治 NTT 解决。
那么实现的部分就做完了。
接下来讲时间复杂度证明:
首先我们的时间主要都是耗在两种情况的分治 NTT 上面,所以我们下面只考虑这两种分治 NTT 所消耗的时间。
首先,假设某次分治 NTT 最后得到的式子是 \(l\) 次的,那么这次分治 NTT 所消耗的时间是 \(O(l\log^2 l)\) 级别的。
设 \(leaf_i\) 表示以 \(i\) 为根的子树内的叶子数目。那么根据 dp 的定义,\(G_u(x)\) 的次数最多不会超过 \(leaf_u\)。
我们先考虑第二种情况的分治 NTT,即求每条链头的 \(G(x)\) 的时间总和:(其中 \(top\) 是链头的集合)
考虑 \(\sum\limits_{u\in top} leaf_u\) 是什么级别的:由于每个叶子往上最多给 \(\log n\) 个链头的 \(leaf\) 贡献 \(1\),而最多有 \(n\) 个叶子,所以这个东西是 \(n\log n\) 级别的。
所以这种情况的总时间:
现在考虑第一种情况的分治 NTT:\(F_u(x)=\prod\limits_{(u,v)\atop v\neq son_u}G_v(x)\)。
\(F_u\) 的次数为轻儿子的 \(leaf\) 的总和。
我们还是总体考虑:
我们像第二种情况一样类似地考虑 \(\sum\limits_{u=1}^n\sum\limits_{(u,v)\atop v\neq son_u}leaf_v\) 的级别:由于每个叶子往上最多跳 \(\log n\) 条轻边,所以最多给 \(\log n\) 个轻儿子的 \(leaf\) 贡献 \(1\),所以这个东西也是 \(n\log n\) 级别的。
所以第二种情况的总时间:
所以算法的总时间复杂度是 \(O(n\log^3 n)\) 的,但由于树剖 \(\log\) 小和跑不满等原因可以跑过。
zjr 巨佬有优秀的 \(O(n\log^2 n)\) 做法,但蒟蒻没听懂(((
代码如下:
#include<bits/stdc++.h>
#define LN 19
#define N 100010
using namespace std;
namespace modular
{
const int mod=998244353;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
}using namespace modular;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
inline int poww(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
typedef vector<int> poly;
int tot;
int rev[N<<2],w[LN][N<<2][2];
poly F[N],G[N],f[N<<2],g[N<<2],q[N];
void init(int n)
{
int limit=1;
while(limit<=(n<<1)) limit<<=1;
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
int len=mid<<1;
int gn=poww(3,(mod-1)/len);
int ign=poww(gn,mod-2);
int g=1,ig=1;
for(int j=0;j<mid;g=mul(g,gn),ig=mul(ig,ign),j++)
w[bit][j][0]=g,w[bit][j][1]=ig;
}
}
void NTT(int *a,int limit,int opt)
{
opt=(opt<0);
for(int i=0;i<limit;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)*(limit>>1));
for(int i=0;i<limit;i++)
if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
for(int i=0,len=mid<<1;i<limit;i+=len)
{
for(int j=0;j<mid;j++)
{
int x=a[i+j],y=mul(w[bit][j][opt],a[i+mid+j]);
a[i+j]=add(x,y),a[i+mid+j]=dec(x,y);
}
}
}
if(opt)
{
int tmp=poww(limit,mod-2);
for(int i=0;i<limit;i++)
a[i]=mul(a[i],tmp);
}
}
poly polymul(poly a,poly b)
{
static int A[N<<2],B[N<<2];
int limit=1,siza=a.size(),sizb=b.size();
if((!siza)&&(!sizb))
{
poly ans(0);
return ans;
}
while(limit<siza+sizb) limit<<=1;
for(int i=0;i<siza;i++) A[i]=a[i];
for(int i=0;i<sizb;i++) B[i]=b[i];
NTT(A,limit,1),NTT(B,limit,1);
for(int i=0;i<limit;i++) A[i]=mul(A[i],B[i]);
NTT(A,limit,-1);
poly ans(siza+sizb-1);
for(int i=0;i<siza+sizb-1;i++) ans[i]=A[i];
for(int i=0;i<limit;i++) A[i]=B[i]=0;
return ans;
}
poly polyadd(poly a,poly b)
{
if(a.size()<b.size())
{
for(int i=0,size=a.size();i<size;i++) b[i]=add(b[i],a[i]);
return b;
}
else
{
for(int i=0,size=b.size();i<size;i++) a[i]=add(a[i],b[i]);
return a;
}
}
void solve1(int k,int l,int r)
{
if(l==r)
{
f[k]=q[l];
return;
}
int mid=(l+r)>>1,lc=k<<1,rc=k<<1|1;
solve1(lc,l,mid),solve1(rc,mid+1,r);
f[k]=polymul(f[lc],f[rc]);
}
void solve2(int k,int l,int r)
{
if(l==r)
{
f[k]=g[k]=q[l];
return;
}
int mid=(l+r)>>1,lc=k<<1,rc=k<<1|1;
solve2(lc,l,mid),solve2(rc,mid+1,r);
f[k]=polymul(f[lc],f[rc]);
g[k]=polyadd(g[lc],polymul(f[lc],g[rc]));
}
int n;
int cnt,head[N],nxt[N<<1],to[N<<1];
int fa[N],d[N],size[N],son[N],leaf[N];
bool top[N];
void adde(int u,int v)
{
to[++cnt]=v;
nxt[cnt]=head[u];
head[u]=cnt;
}
void dfs(int u)
{
size[u]=1;
if(!head[u]) leaf[u]=1;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
d[v]=d[u]+1;
dfs(v);
leaf[u]+=leaf[v];
size[u]+=size[v];
if(size[v]>size[son[u]]) son[u]=v;
}
for(int i=head[u];i;i=nxt[i])
if(to[i]!=son[u]) top[to[i]]=1;
}
void dfs1(int u)
{
for(int i=head[u];i;i=nxt[i]) dfs1(to[i]);
tot=0;
for(int i=head[u];i;i=nxt[i])
if(to[i]!=son[u]) q[++tot]=G[to[i]];
// for(int i=head[u];i;i=nxt[i]) if(to[i]!=son[u]) q[++tot]=G[to[i]];
if(tot)
{
solve1(1,1,tot);
F[u]=f[1];
}
else F[u].push_back(1);
// if(tot){solve1(1,1,tot);F[u]=f[1];}
if(top[u])
{
tot=0;
int now=u;
while(son[now])
{
q[++tot]=F[now];
now=son[now];
}
// while(son[now]){q[++tot]=F[now];now=son[now];}
if(tot)
{
solve2(1,1,tot);
G[u]=g[1];
}
if(G[u].size()) G[u][0]=add(G[u][0],1);
else G[u].push_back(1);
G[u].push_back(114514);
for(int i=G[u].size()-1;i>=1;i--) G[u][i]=G[u][i-1];
G[u][0]=0;
}
}
int main()
{
// freopen("tree3.in","r",stdin);
// freopen("tree3.out","w",stdout);
// freopen("in.txt","r",stdin);
// freopen("out.txt","w",stdout);
n=read();
init(n);
for(int i=2;i<=n;i++)
{
fa[i]=read();
adde(fa[i],i);
}
dfs(1);
top[1]=1;
dfs1(1);
int ans=0,fac=1;
for(int i=1;i<=n;i++)
{
fac=mul(fac,i);
if(G[1].size()>i) ans=add(ans,mul(read(),mul(G[1][i],fac)));
else break;
}
printf("%d\n",ans);
return 0;
}
/*
3
1 1
3 1 7
*/
/*
6
1 2 1 2 1
3 3 3 2 3 2
*/
/*
3
1 2
114514 1919810 1919810
*/