2022 牛客多校5 D(计数树dp)

2022 牛客多校5 D(计数树dp)

原题意非常谜语,翻成人话后是一个经典问题

题意

对一棵树,每个点染成黑白二色。求有多少个连通子图,满足叶子颜色相同。

叶子:子图中度数为1的点。

思路

树上黑白二染色的计数问题一般考虑树dp。该题比较特殊的地方在于,我们仅需叶子的染色情况。

如果我们定义 \(f[i][0/1]\) 表示以 \(i\) 为根,染色为 \(0/1\) 的答案。

设当前考虑根为 \(u\) ,根的颜色为 \(c\) ,对它的儿子 \(v\) 。如果考虑 \(f[u][c]\) 每一棵以 \(v\) 为根的子树都独立的贡献 \(f[v][c]+1\)

\[f[u][c] = \prod_{v \in S_{son}} (f[v][c]+1) \]

特别地,当 \(u\) 是叶子时,\(f[u][c] = 1\)

但当考虑颜色 \(t = c \oplus 1\) 时,根度数不能为 \(1\) 。如果仍按如上转移,会出现 “根颜色为 \(c\),其他叶子颜色为 \(t\) ”的非法情况。

考虑减去这个非法情况。显然,只有当 \(u\) 仅连接一棵子树时 \(u\) 是叶子。因此我们可以用 \(sum\) 记录这个非法情况。

\[sum = sum + \sum_{v \in S_{son}} f[v][t] \]

另外还有一个细节,\(f[u][t]\) 中实际上还记录了单点 \(u\) 的情况,这也是我们不希望要的,因此总的转移方程为

\[f[u][t] = \prod_{v \in S_{son}} (f[v][t]+1)-1 \]

最后答案为 \(f[i][0/1]\) 的求和减去 \(sum\)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#include<queue>
#include<map>
#include<stack>
#include<string>
#include<random>
#include<iomanip>
#include<functional>
#define yes puts("yes");
#define inf 0x3f3f3f3f
#define ll long long
#define linf 0x3f3f3f3f3f3f3f3f
#define ull unsigned long long
#define endl '\n'
#define int long long
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
using namespace std;
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x;}
using PII = array<int,2>;
const int MAXN =10 + 2e5 ,mod=1e9 + 7;

void solve()
{    
    int n; cin >> n;
    string s; cin >> s;
    s = ' ' + s;
    vector<vector<int>> adj(n + 1);
    for(int i = 0;i < n - 1;i += 1) {
        int u,v; cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    vector f(n + 1,vector(2,0ll));
    int sum = 0;
    function<void(int,int)> dfs = [&](int u,int fa) {
        f[u][0] = f[u][1] = 1;
        int t = (s[u] - '0') ^ 1;
        for(auto v : adj[u]) if(v != fa) {
            dfs(v,u);
            for(int c = 0;c < 2;c += 1) {
                f[u][c] *= f[v][c] + 1;
                f[u][c] %= mod;
            }
            sum += f[v][t];
            sum %= mod;
        }
        f[u][t] = (f[u][t] - 1) % mod;
        f[u][t] = (f[u][t] + mod) % mod;
    };
    dfs(1,0);

    int ans = 0;
    for(int i = 1;i <= n;i += 1) {
        ans = (ans + f[i][0] + f[i][1]) % mod;
    }
    ans = ((ans - sum) % mod + mod) % mod;
    
    cout << ans;
}
signed main()
{
    ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);

    //int T;cin>>T;
    //while(T--)
        solve();

    return 0;
}
posted @ 2022-08-02 19:59  Mxrurush  阅读(18)  评论(1编辑  收藏  举报