Luogu P11363 NOIP2024 树的遍历 题解 [ 紫 ] [ 树形 dp ] [ 组合计数 ] [ adhoc ]

树上遍历:CCF 难得一遇的好题!

参考了洛谷的第一篇题解,所以思路会有点相似。

部分分

k=1 时,显然方案总数为 i=1n(di1)!,因为进入一个子节点后可以以任意顺序遍历它的所有出边。

观察

当遍历出来的树的形态确定时,能形成这棵树的边组合在一起一定是一条链,且这条链是这棵树上两个不同的叶子节点之间的链。

这个手模几组样例就应该理解了,证明比较感性,观察可得一个节点的所有边一定在遍历出的树中是一条链。所以我们一定要从一个边以一条链的路径走到其他的所有边,于是这条链上中间的边就是无法作为起点的,因为去了某个边之后就回不来了。

每个遍历生成树都一定都这样的一条链,且一旦确定这条链,那么生成的新树就有 i=1n(di1[iV])! 种,其中 V 表示这条链上的节点。理解就是进入一个子节点之后,在链上的点已经确定了它从哪里来,并且确定了它最后走的边是哪条,因此是 di2

这个公式可以等价转化为 i=1n(di1)!×i=1|V|(dVi1)1,于是我们就可以开始树形 dp 了。

因为 i=1n(di1)! 的系数是所有链都要乘的,所以我们把它提出来放到最后再乘。

树形 dp

显然,现在这个问题已经被转化为了求解带权的每条连接两个叶子链的总和是多少。

我们设计 dpi,0/1 表示以节点 i 为根的子树中目前一共有多少种合法子树方案,且当前子树中有没有关键边。

显然,我们可以每次遍历 i 节点的一个子树之后,在他们的 LCA,即节点 i 处统计答案,即可不重不漏。

先能写出遍历到当前子树的答案:

  1. 当连接子树的边是关键边时,此时一定满足统计进答案的标准,那么 res=res+(dpv,0+dpv,1)×(dpi,0+dpi,1),可以通过把他们乘开来理解这个式子。
  2. 当连接子树的边是关键边时,此时不一定满足统计进答案的标准,那么就要让前面遍历的子树或者当前子树中至少存在一条关键边,则 res=res+dpi,0×dpv,1+dpi,1×dpv,1+dpi,1×dpv,0

接下来考虑合并子树的 dp 值:

  1. 当连接的子树是关键边时,此时以 i 为根的子树内一定含有关键边,那么只转移一个式子即可:

dpi,1=dpi,1+dpv,0+dpv,1

  1. 当连接的子树不是关键边时,就都能转移:

dpi,1=dpi,1+dpv,1

dpi,0=dpi,0+dpv,0

注意在最后要乘上 (di1)1,包括 resdp 值,才能保证求的是这个式子。

如果当前遍历到了叶子节点,那么注意在 dp 完后把 dpi,0 赋值为 1,才能统计到答案。这也是为什么不能从一个叶子节点开始 dfs 的原因,会漏加这个 1

时间复杂度 O(Tn)

代码

链式前向星记得开双倍空间!!!

注意特判 n=2

#include <bits/stdc++.h>
#define fi first
#define se second
#define lc (p<<1)
#define rc ((p<<1)|1)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pi;
const ll mod=1e9+7;
const int N=200005;
ll inv[N+100];
int n,m,eu[N],ev[N];
int h[N],e[N],ne[N],idx,id[N],d[N];
ll ans=0,dp[N][2];
bitset<N>sig;
void init()
{
    sig.reset();
    memset(h,-1,sizeof(h));
    memset(d,0,sizeof(d));
    idx=0;
    ans=0;
}
void add(int u,int v,int x)
{
    idx++;
    ne[idx]=h[u];
    h[u]=idx;
    e[idx]=v;
    id[idx]=x;
    d[u]++;
}
void dfs(int u,int fa)
{
    ll res=0;
    dp[u][0]=dp[u][1]=0;
    for(int i=h[u];i!=-1;i=ne[i])
    {
        int v=e[i],x=id[i];
        if(v==fa)continue;
        dfs(v,u);
        if(sig[x])
        {
            res=(res+(dp[v][1]+dp[v][0])*(dp[u][1]+dp[u][0])%mod)%mod;
            dp[u][1]=(dp[u][1]+dp[v][0]+dp[v][1])%mod;
        }
        else
        {
            res=(res+dp[u][1]*dp[v][0]%mod+dp[u][1]*dp[v][1]%mod+dp[u][0]*dp[v][1]%mod)%mod;
            dp[u][1]=(dp[u][1]+dp[v][1])%mod;
            dp[u][0]=(dp[u][0]+dp[v][0])%mod;
        }
    }
    if(d[u]==1)dp[u][0]++;
    ans=(ans+res*inv[d[u]-1]%mod)%mod;
    dp[u][0]=(dp[u][0]*inv[d[u]-1])%mod;
    dp[u][1]=(dp[u][1]*inv[d[u]-1])%mod;
}
void solve()
{
    scanf("%d%d",&n,&m);
    init();
    for(int i=1;i<n;i++)
    {
        int u,v;
        scanf("%d%d",&eu[i],&ev[i]);
        add(eu[i],ev[i],i);
        add(ev[i],eu[i],i);
    }
    for(int i=1;i<=m;i++)
    {
        int x;
        scanf("%d",&x);
        sig[x]=1;
    }
    if(n<=3)
    {
        printf("1\n");
        return;
    }
    for(int i=1;i<=n;i++)
    {
        if(d[i]>1)
        {
            dfs(i,-1);
            break;
        }
    }
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<d[i];j++)
        {
            ans=(ans*j)%mod;
        }
    }
    printf("%lld\n",ans);
}
int main()
{
    freopen("traverse.in","r",stdin);
    freopen("traverse.out","w",stdout);
    inv[0]=inv[1]=1;
    for(int i=2;i<=N;i++)inv[i]=(mod-mod/i)*inv[mod%i]%mod;
    int c,t;
    scanf("%d%d",&c,&t);
    while(t--)solve();
    return 0;
}
posted @   KS_Fszha  阅读(61)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 清华大学推出第四讲使用 DeepSeek + DeepResearch 让科研像聊天一样简单!
· 推荐几款开源且免费的 .NET MAUI 组件库
· 实操Deepseek接入个人知识库
· 易语言 —— 开山篇
· Trae初体验
点击右上角即可分享
微信分享提示