P6478 题解
P6478 题解
可能巨佬们都觉得树形背包的时间复杂度分析太简单了,
好像都没写或者只是点了一句话,那我就来补充一下。
题意:
给定一棵点数为 \(n=2m\) 的有根树,每个点有 \(0,1\) 两种边权。
现在要为每一个权为 \(0\) 的点找一个权为 \(1\) 的点与之配对,并对每个 \(k \in [0,\frac {n}{2}]\),
求出恰有 \(k\) 对点的关系是祖先和后代的配对方案数。
做法:
我们记 \(f_k\) 为恰有 \(k\) 对点的关系是祖先和后代的配对方案数,
记 \(g_k\) 为钦定 \(k\) 对点的关系是祖先和后代的配对方案数,
那么我们有:\(g_k= \sum\limits_{i=k}^{m}\dbinom{i}{k}f_i\),式子的含义是:
先枚举一共有多少对祖先-后代点对,再枚举在这些点对中,我们钦定的是哪 \(k\) 对点。
我们发现,这个式子是一个二项式反演的经典形式,故我们有:
\(f_k=\sum\limits_{i=k}^{m}\dbinom{i}{k}(-1)^{i-k}g_i\),至于二项式反演的证明我就不讲了,相信大家也都会。
也就是说,我们将问题转化为了,对于每个 \(k \in [0,\frac {n}{2}]\),
求出钦定 \(k\) 对点的关系是祖先和后代的配对方案数。
这个问题就很好解决了,我们可以使用树上背包解决。具体来说,就是:
设 \(f_{u,i}\) 为子树 \(u\) 中选择了 \(i\) 对祖先-后代点对的方案数,则我们有一个显然的转移:
\(f_{u,i}=\sum\limits_{v,j\le i} f_{v,j}f_{u,i-j}\),代表在子树 \(v\) 中选择 \(j\) 对祖先-后代点,其余的在子树 \(u\) 其他部位选。
当然,我们还可以在子树 \(u\) 内选择一个点与 \(u\) 配对,即:
\(f_{u,i}=f_{u,i-1}+siz_{u,1-col_u}\),其中 \(col_u\) 代表点 \(u\) 属于谁,\(siz_{u,c}\) 代表子树 \(u\) 中 \(col\) 值为 \(c\) 的点数。
注意这里需要倒序枚举 \(i\),因为点 \(u\) 只能被配对一次,故应该做 \(01\) 背包。
看上去这个做法是立方级别的做法,但实际上并不是,我们可以把复杂度的式子写下来。
先记 \(son_{u,x}\) 为 \(u\) 的第 \(x\) 个子节点,以及 \(s(u)\) 为子树 \(u\) 的大小,则我们有:
\(T(n)=\sum\limits_{u}\sum\limits_{v=son_{u,i}}((\sum\limits_{j=1}^{i-1}s(son_{u,j})) \times s(v))\).
我们可以将上式理解成,在枚举到 \(u\) 和 \(u\) 的第 \(i\) 个儿子 \(v\) 时,对于 \(u\) 的所有已经枚举过的儿子,
这些点以及它们子树中的所有点构成了一个点集 \(V\);
而我们正在进行的这一次枚举,会对时间复杂度造成 \(|V| \times s(v)\) 的贡献,
并将点 \(v\) 和 \(v\) 的子树中所有点合并进点集 \(V\) 中。
我们考虑拆开贡献计算,具体来说,是:
因为每次枚举时,任一点集 \(V\) 中的点 \(p\),和任一点 \(v\) 及其子树中的点 \(q\),
都会对上面时间复杂度中的 \(|V| \times s(v)\) 造成值为 \(1\) 的贡献。
又因为 \(p\) 和 \(q\) 在树上的最近公共祖先必然是 \(u\),也就是说,在这次枚举之前,
我们一定没有枚举到过 \(u,v\),故这是点对 \(p,q\) 第一次造成贡献,
但我们在枚举完 \(u,v\) 后,会把点 \(v\) 和 \(v\) 的子树中所有点合并进点集 \(V\) 中,
那么这也是是点对 \(p,q\) 最后一次造成贡献,因为点集 \(V\) 不会分裂,
且任意同属于点集 \(V\) 中的两点不会造成贡献。
以上证明说明了,对 \(\forall u,v \in [1,n]\),点对 \(u,v\) 只会对时间复杂度造成一次值为 \(1\) 的贡献。
也就是说,最后时间复杂度只会被累加 \(O(n^2)\) 次,即 \(T(n)=O(n^2)\).
code:
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define int long long
#define pii pair < int , int >
#define swap(u, v) u ^= v, v ^= u, u ^= v
#define ckmax(a, b) ((a) = max((a), (b)))
#define ckmin(a, b) ((a) = min((a), (b)))
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define per(i, a, b) for (int i = (a); i >= (b); i--)
#define edg(i, v, u) for (int i = head[u], v = e[i].to; i; i = e[i].nxt, v = e[i].to)
using namespace std;
inline int read() {
int x = 0, f = 1; char ch = getchar();
while (ch < '0' || ch > '9') f = ch == '-' ? -1 : 1, ch = getchar();
while (ch >= '0' && ch <= '9') x = x * 10 + ch - 48, ch = getchar();
return x * f;
}
const int N (5e3 + 10);
const int mod (998244353);
void add (int &x, int y) { x = (x + y) % mod; }
int ksm (int a, int b) {
int r = 1;
for (; b; a = a * a % mod, b >>= 1) if (b & 1) r = r * a % mod;
return r;
}
int n;
int cnt;
int g[N];
char s[N];
int fac[N];
int inv[N];
int head[N];
int f[N][N];
int siz[N][2];
struct Edge { int to, nxt; } e[N << 1];
void adde (int u, int v) {
e[++cnt] = (Edge) {v, head[u]}, head[u] = cnt;
}
int Siz (int x) { return siz[x][0] + siz[x][1]; }
int C (int bi, int sm) {
return fac[bi] * inv[sm] % mod * inv[bi - sm] % mod;
}
void dfs (int u, int fa) {
f[u][0] = 1;
rep (c, 0, 1) siz[u][c] = (s[u] == ('0' + c));
edg (pth, v, u) if (v ^ fa) {
dfs (v, u); int Su = Siz(u), Sv = Siz(v);
rep (i, 0, Su + Sv) g[i] = 0;
rep (i, 0, min (Su, n / 2)) rep (j, 0, min (Sv, n / 2 - i))
add (g[i + j], f[v][j] * f[u][i] % mod);
rep (i, 0, Su + Sv) f[u][i] = g[i];
rep (c, 0, 1) siz[u][c] += siz[v][c];
}
per (i, min (siz[u][0], siz[u][1]), 1) {
int x = ((s[u] == '1') ? siz[u][0] : siz[u][1]) - (i - 1);
add (f[u][i], f[u][i - 1] * x % mod);
}
}
signed main() {
n = read(), cin >> (s + 1);
rep (i, 2, n) {
int u = read(), v = read();
adde (u, v), adde (v, u);
}
fac[0] = inv[0] = 1;
rep (i, 1, n + 2) fac[i] = fac[i - 1] * i % mod,
inv[i] = ksm (fac[i], mod - 2);
dfs (1, 0);
rep (k, 0, n / 2 + 1) (f[1][k] *= fac[n / 2 - k]) %= mod;
rep (k, 0, n / 2) {
int res = 0, sgn = mod - 1;
rep (i, k, n / 2) {
int tmp = C (i, k) * ((i - k & 1) ? mod - 1 : 1) % mod;
res = (res + tmp * f[1][i] % mod) % mod;
}
cout << res << endl;
}
return 0;
}