模拟赛11.7 解题报告
额……好久没有更新这个系列了……
T3
题意:给出一棵树 \(T_0\),大小为 \(n\)。
有 \(k\) 次操作,第 \(i\) 次在树 \(T_{i-1}\) 中选择一条边切断成两部分,然后选择其中一部分作为 \(T_i\)。
给出 \(a_{1...k}\),需要满足 \(|T_i|=a_i\)。
\(1\le n\le 5000,\space 1\le k\le 6\)
考虑每次切边,分成两部分:
-
子树
-
非子树
称选择子树的切边操作为一操作,选择非子树的切边操作为二操作。
我们需要充分利用 \(a\) 的信息,第 \(i\) 次操作时的树大小为 \(a_{i-1}\)。
对于第 \(i\) 次操作,若其是二操作,不难发现切的边对应的子树大小应是 \(a_{i-1}-a_i\)。我们不需要考虑切完后子树外大小为 \(a_i\),只需要满足切的这条边下面的子树大小是 \(a_{i-1}-a_i\),这样限制会宽松很多。
若其是一操作,选的子树大小就是 \(a_i\)。
先不考虑一操作。\(k\) 很小,往状压方向想。设 \(g[u,S]\) 表示 \(u\) 子树内完成了 \(S\) 集合内的操作,且都保留了非子树部分(二操作)的方案数。
显然 \(g[u,\emptyset]=1\)。转移时只需要枚举子集,暴力合并即可。
考虑单个贡献,即:若 \(u\) 不是根,我们可以对 \((fa_u,u)\) 这条边进行操作。对于 \(g[u,S]\),选择 \(S\) 中最大的操作编号 \(x\),判断在其他操作的影响下,\(u\) 子树大小是否为 \(a_{x-1}-a_x\),若满足则转移 \(g[u,S]\leftarrow g[u,S-\{x\}]\)。
考虑一操作。
不难发现,每次的一操作会不断往子树方向走,所有一操作的边应在一条链上且操作时间单调递增。
设 \(f[u,S,i]\) 表示 \(u\) 子树内完成了 \(S\) 集合内的操作,其中最晚的那次一操作是第 \(i\) 次操作,的方案数。
考虑不操作 \((fa_u,u)\) 这条边。那么我们合并所有儿子子树的信息,最多有一个儿子满足其子树内有一操作,其他儿子按 \(g\) 的合并方法合并,那个儿子也不难转移。
如果操作 \((fa_u,u)\) 这条边,枚举是第 \(x\) 次操作,那么子树内其他一操作的时间 \(i\) 都必须 \(<x\)。用 \(g[u,S-\{x\}],f[u,S-\{x\},i]\) 向 \(f[u,S,x]\) 贡献即可。
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define max(a,b) ((a)>(b)? (a):(b))
#define min(a,b) ((a)<(b)? (a):(b))
using namespace std;
const ll maxn=5010,mod=998244353;
ll n,u,v,head[maxn],tot,m,a[maxn],f[maxn][1<<6][8],g[maxn][1<<6],d[1<<6],siz[maxn],mxbit[1<<6],ans;
ll L[1<<6][8],sum[1<<6];
struct edge
{
ll v,nxt;
}e[maxn<<1];
void ins(ll u,ll v)
{
e[++tot]=(edge){v,head[u]};
head[u]=tot;
}
void dfs(ll u,ll fa)
{
g[u][0]=1; siz[u]=1;
for(ll i=head[u];i;i=e[i].nxt)
{
ll v=e[i].v;
if(v==fa) continue;
dfs(v,u); siz[u]+=siz[v];
}
for(ll i=head[u];i;i=e[i].nxt)
{
ll v=e[i].v;
if(v==fa) continue;
for(ll S=(1<<m)-1;S;S--)
{
for(ll j=1;j<=m;j++)
if(S&(1<<j-1))
{
for(ll T=S;T;T=(T-1)&S)
{
if(T&(1<<j-1))
{
if(mxbit[S^T]<j)
{
f[u][S][j]=(f[u][S][j]+g[u][S^T]*f[v][T][j])%mod;
}
}
if(mxbit[T]>=j-1) continue;
f[u][S][j]=(f[u][S][j]+f[u][S^T][j]*g[v][T])%mod;
}
}
}
for(ll S=(1<<m)-1;S;S--)
{
for(ll T=S;T;T=(T-1)&S)
{
g[u][S]=(g[u][S]+g[u][S^T]*g[v][T])%mod;
}
}
}
if(u==1) return;
memset(L,0,sizeof L);
for(ll i=1;i<=m;i++)
{
for(ll S=(1<<m)-1;S;S--)
if((S&(1<<i-1))&&(a[i]==siz[u]-sum[S&((1<<i-1)-1)]))
{
L[S][i]=(L[S][i]+g[u][S^(1<<i-1)])%mod;
for(ll j=i+1;j<=m;j++)
if(S&(1<<j-1))
{
L[S][i]=(L[S][i]+f[u][S^(1<<i-1)][j])%mod;
}
}
}
for(ll S=(1<<m);S;S--)
{
ll p=mxbit[S]+1;
if(a[p-1]-a[p]==siz[u]-sum[S^(1<<p-1)]) g[u][S]=(g[u][S]+g[u][S^(1<<p-1)])%mod;
}
for(ll S=0;S<(1<<m);S++)
for(ll i=1;i<=m;i++)
f[u][S][i]=(f[u][S][i]+L[S][i])%mod;//, printf("f[%lld,%lld,%lld] = %lld\n",u,S,i,f[u][S][i]);
}
int main()
{
freopen("lone.in","r",stdin);
freopen("lone.out","w",stdout);
scanf("%lld",&n);
for(ll i=1;i<n;i++)
{
scanf("%lld%lld",&u,&v);
ins(u,v); ins(v,u);
}
scanf("%lld",&m); a[0]=n;
for(ll i=1;i<=m;i++) scanf("%lld",a+i);
mxbit[0]=-1;
for(ll i=1;i<(1<<m);i++)
mxbit[i]=mxbit[i>>1]+1;
for(ll i=1;i<(1<<m);i++)
{
for(ll j=1;j<=m;j++)
if(i&(1<<j-1)) sum[i]+=a[j-1]-a[j];
}
dfs(1,0);
ans=g[1][(1<<m)-1];
for(ll i=1;i<=m;i++) ans=(ans+f[1][(1<<m)-1][i])%mod;
printf("%lld",ans);
return 0;
}