2022 牛客多校5 D(计数树dp)
2022 牛客多校5 D(计数树dp)
原题意非常谜语,翻成人话后是一个经典问题
题意
对一棵树,每个点染成黑白二色。求有多少个连通子图,满足叶子颜色相同。
叶子:子图中度数为1的点。
思路
树上黑白二染色的计数问题一般考虑树dp。该题比较特殊的地方在于,我们仅需叶子的染色情况。
如果我们定义 \(f[i][0/1]\) 表示以 \(i\) 为根,染色为 \(0/1\) 的答案。
设当前考虑根为 \(u\) ,根的颜色为 \(c\) ,对它的儿子 \(v\) 。如果考虑 \(f[u][c]\) 每一棵以 \(v\) 为根的子树都独立的贡献 \(f[v][c]+1\) 。
\[f[u][c] = \prod_{v \in S_{son}} (f[v][c]+1)
\]
特别地,当 \(u\) 是叶子时,\(f[u][c] = 1\)
但当考虑颜色 \(t = c \oplus 1\) 时,根度数不能为 \(1\) 。如果仍按如上转移,会出现 “根颜色为 \(c\),其他叶子颜色为 \(t\) ”的非法情况。
考虑减去这个非法情况。显然,只有当 \(u\) 仅连接一棵子树时 \(u\) 是叶子。因此我们可以用 \(sum\) 记录这个非法情况。
\[sum = sum + \sum_{v \in S_{son}} f[v][t]
\]
另外还有一个细节,\(f[u][t]\) 中实际上还记录了单点 \(u\) 的情况,这也是我们不希望要的,因此总的转移方程为
\[f[u][t] = \prod_{v \in S_{son}} (f[v][t]+1)-1
\]
最后答案为 \(f[i][0/1]\) 的求和减去 \(sum\)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#include<queue>
#include<map>
#include<stack>
#include<string>
#include<random>
#include<iomanip>
#include<functional>
#define yes puts("yes");
#define inf 0x3f3f3f3f
#define ll long long
#define linf 0x3f3f3f3f3f3f3f3f
#define ull unsigned long long
#define endl '\n'
#define int long long
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
using namespace std;
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x;}
using PII = array<int,2>;
const int MAXN =10 + 2e5 ,mod=1e9 + 7;
void solve()
{
int n; cin >> n;
string s; cin >> s;
s = ' ' + s;
vector<vector<int>> adj(n + 1);
for(int i = 0;i < n - 1;i += 1) {
int u,v; cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
vector f(n + 1,vector(2,0ll));
int sum = 0;
function<void(int,int)> dfs = [&](int u,int fa) {
f[u][0] = f[u][1] = 1;
int t = (s[u] - '0') ^ 1;
for(auto v : adj[u]) if(v != fa) {
dfs(v,u);
for(int c = 0;c < 2;c += 1) {
f[u][c] *= f[v][c] + 1;
f[u][c] %= mod;
}
sum += f[v][t];
sum %= mod;
}
f[u][t] = (f[u][t] - 1) % mod;
f[u][t] = (f[u][t] + mod) % mod;
};
dfs(1,0);
int ans = 0;
for(int i = 1;i <= n;i += 1) {
ans = (ans + f[i][0] + f[i][1]) % mod;
}
ans = ((ans - sum) % mod + mod) % mod;
cout << ans;
}
signed main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//int T;cin>>T;
//while(T--)
solve();
return 0;
}