P6478 题解
可能巨佬们都觉得树形背包的时间复杂度分析太简单了,
(只有我想了若干小时还很晕)
好像都没写或者只是点了一句话,那我就来补充一下。
题意:
给定一棵点数为 n=2mn=2m 的有根树,每个点有 0,10,1 两种边权。
现在要为每一个权为 00 的点找一个权为 11 的点与之配对,并对每个 k \in [0,\frac {n}{2}]k∈[0,2n],
求出恰有 kk 对点的关系是祖先和后代的配对方案数。
做法:
我们记 f_kfk 为恰有 kk 对点的关系是祖先和后代的配对方案数,
记 g_kgk 为钦定 kk 对点的关系是祖先和后代的配对方案数,
那么我们有:g_k= \sum\limits_{i=k}^{m}\dbinom{i}{k}f_igk=i=k∑m(ki)fi,式子的含义是:
先枚举一共有多少对祖先-后代点对,再枚举在这些点对中,我们钦定的是哪 kk 对点。
(随您抄袭)
我们发现,这个式子是一个二项式反演的经典形式,故我们有:
f_k=\sum\limits_{i=k}^{m}\dbinom{i}{k}(-1)^{i-k}g_ifk=i=k∑m(ki)(−1)i−kgi,至于二项式反演的证明我就不讲了,相信大家也都会。
也就是说,我们将问题转化为了,对于每个 k \in [0,\frac {n}{2}]k∈[0,2n],
求出钦定 kk 对点的关系是祖先和后代的配对方案数。
这个问题就很好解决了,我们可以使用树上背包解决。具体来说,就是:
设 f_{u,i}fu,i 为子树 uu 中选择了 ii 对祖先-后代点对的方案数,则我们有一个显然的转移:
f_{u,i}=\sum\limits_{v,j\le i} f_{v,j}f_{u,i-j}fu,i=v,j≤i∑fv,jfu,i−j,代表在子树 vv 中选择 jj 对祖先-后代点,其余的在子树 uu 其他部位选。
当然,我们还可以在子树 uu 内选择一个点与 uu 配对,即:
f_{u,i}=f_{u,i-1}+siz_{u,1-col_u}fu,i=fu,i−1+sizu,1−colu,其中 col_ucolu 代表点 uu 属于谁,siz_{u,c}sizu,c 代表子树 uu 中 colcol 值为 cc 的点数。
注意这里需要倒序枚举 ii,因为点 uu 只能被配对一次,故应该做 0101 背包。
看上去这个做法是立方级别的做法,但实际上并不是,我们可以把复杂度的式子写下来。
先记 son_{u,x}sonu,x 为 uu 的第 xx 个子节点,以及 s(u)s(u) 为子树 uu 的大小,则我们有:
T(n)=\sum\limits_{u}\sum\limits_{v=son_{u,i}}((\sum\limits_{j=1}^{i-1}s(son_{u,j})) \times s(v))T(n)=u∑v=sonu,i∑((j=1∑i−1s(sonu,j))×s(v)).
我们可以将上式理解成,在枚举到 uu 和 uu 的第 ii 个儿子 vv 时,对于 uu 的所有已经枚举过的儿子,
这些点以及它们子树中的所有点构成了一个点集 VV;
而我们正在进行的这一次枚举,会对时间复杂度造成 |V| \times s(v)∣V∣×s(v) 的贡献,
并将点 vv 和 vv 的子树中所有点合并进点集 VV 中。
我们考虑拆开贡献计算,具体来说,是:
因为每次枚举时,任一点集 VV 中的点 pp,和任一点 vv 及其子树中的点 qq,
都会对上面时间复杂度中的 |V| \times s(v)∣V∣×s(v) 造成值为 11 的贡献。
又因为 pp 和 qq 在树上的最近公共祖先必然是 uu,也就是说,在这次枚举之前,
我们一定没有枚举到过 u,vu,v,故这是点对 p,qp,q 第一次造成贡献,
但我们在枚举完 u,vu,v 后,会把点 vv 和 vv 的子树中所有点合并进点集 VV 中,
那么这也是是点对 p,qp,q 最后一次造成贡献,因为点集 VV 不会分裂,
且任意同属于点集 VV 中的两点不会造成贡献。
以上证明说明了,对 \forall u,v \in [1,n]∀u,v∈[1,n],点对 u,vu,v 只会对时间复杂度造成一次值为 11 的贡献。
也就是说,最后时间复杂度只会被累加 O(n^2)O(n2) 次,即 T(n)=O(n^2)T(n)=O(n2).
code:
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define int long long
#define pii pair < int , int >
#define swap(u, v) u ^= v, v ^= u, u ^= v
#define ckmax(a, b) ((a) = max((a), (b)))
#define ckmin(a, b) ((a) = min((a), (b)))
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define per(i, a, b) for (int i = (a); i >= (b); i--)
#define edg(i, v, u) for (int i = head[u], v = e[i].to; i; i = e[i].nxt, v = e[i].to)
using namespace std;
inline int read() {
int x = 0, f = 1; char ch = getchar();
while (ch < '0' || ch > '9') f = ch == '-' ? -1 : 1, ch = getchar();
while (ch >= '0' && ch <= '9') x = x * 10 + ch - 48, ch = getchar();
return x * f;
}
const int N (5e3 + 10);
const int mod (998244353);
void add (int &x, int y) { x = (x + y) % mod; }
int ksm (int a, int b) {
int r = 1;
for (; b; a = a * a % mod, b >>= 1) if (b & 1) r = r * a % mod;
return r;
}
int n;
int cnt;
int g[N];
char s[N];
int fac[N];
int inv[N];
int head[N];
int f[N][N];
int siz[N][2];
struct Edge { int to, nxt; } e[N << 1];
void adde (int u, int v) {
e[++cnt] = (Edge) {v, head[u]}, head[u] = cnt;
}
int Siz (int x) { return siz[x][0] + siz[x][1]; }
int C (int bi, int sm) {
return fac[bi] * inv[sm] % mod * inv[bi - sm] % mod;
}
void dfs (int u, int fa) {
f[u][0] = 1;
rep (c, 0, 1) siz[u][c] = (s[u] == ('0' + c));
edg (pth, v, u) if (v ^ fa) {
dfs (v, u); int Su = Siz(u), Sv = Siz(v);
rep (i, 0, Su + Sv) g[i] = 0;
rep (i, 0, min (Su, n / 2)) rep (j, 0, min (Sv, n / 2 - i))
add (g[i + j], f[v][j] * f[u][i] % mod);
rep (i, 0, Su + Sv) f[u][i] = g[i];
rep (c, 0, 1) siz[u][c] += siz[v][c];
}
per (i, min (siz[u][0], siz[u][1]), 1) {
int x = ((s[u] == '1') ? siz[u][0] : siz[u][1]) - (i - 1);
add (f[u][i], f[u][i - 1] * x % mod);
}
}
signed main() {
n = read(), cin >> (s + 1);
rep (i, 2, n) {
int u = read(), v = read();
adde (u, v), adde (v, u);
}
fac[0] = inv[0] = 1;
rep (i, 1, n + 2) fac[i] = fac[i - 1] * i % mod,
inv[i] = ksm (fac[i], mod - 2);
dfs (1, 0);
rep (k, 0, n / 2 + 1) (f[1][k] *= fac[n / 2 - k]) %= mod;
rep (k, 0, n / 2) {
int res = 0, sgn = mod - 1;
rep (i, k, n / 2) {
int tmp = C (i, k) * ((i - k & 1) ? mod - 1 : 1) % mod;
res = (res + tmp * f[1][i] % mod) % mod;
}
cout << res << endl;
}
return 0;
}