P6478 题解

可能巨佬们都觉得树形背包的时间复杂度分析太简单了,

(只有我想了若干小时还很晕)

好像都没写或者只是点了一句话,那我就来补充一下。

题意:

给定一棵点数为 n=2mn=2m 的有根树,每个点有 0,10,1 两种边权。

现在要为每一个权为 00 的点找一个权为 11 的点与之配对,并对每个 k \in [0,\frac {n}{2}]k[0,2n],

求出恰有 kk 对点的关系是祖先和后代的配对方案数。

做法:

我们记 f_kfk 为恰有 kk 对点的关系是祖先和后代的配对方案数,

记 g_kgk 为钦定 kk 对点的关系是祖先和后代的配对方案数,

那么我们有:g_k= \sum\limits_{i=k}^{m}\dbinom{i}{k}f_igk=i=km(ki)fi,式子的含义是:

先枚举一共有多少对祖先-后代点对,再枚举在这些点对中,我们钦定的是哪 kk 对点。

(随您抄袭)

我们发现,这个式子是一个二项式反演的经典形式,故我们有:

f_k=\sum\limits_{i=k}^{m}\dbinom{i}{k}(-1)^{i-k}g_ifk=i=km(ki)(1)ikgi,至于二项式反演的证明我就不讲了,相信大家也都会。

也就是说,我们将问题转化为了,对于每个 k \in [0,\frac {n}{2}]k[0,2n],

求出钦定 kk 对点的关系是祖先和后代的配对方案数。

这个问题就很好解决了,我们可以使用树上背包解决。具体来说,就是:

设 f_{u,i}fu,i 为子树 uu 中选择了 ii 对祖先-后代点对的方案数,则我们有一个显然的转移:

f_{u,i}=\sum\limits_{v,j\le i} f_{v,j}f_{u,i-j}fu,i=v,jifv,jfu,ij,代表在子树 vv 中选择 jj 对祖先-后代点,其余的在子树 uu 其他部位选。

当然,我们还可以在子树 uu 内选择一个点与 uu 配对,即:

f_{u,i}=f_{u,i-1}+siz_{u,1-col_u}fu,i=fu,i1+sizu,1colu,其中 col_ucolu 代表点 uu 属于谁,siz_{u,c}sizu,c 代表子树 uu 中 colcol 值为 cc 的点数。

注意这里需要倒序枚举 ii,因为点 uu 只能被配对一次,故应该做 0101 背包。

看上去这个做法是立方级别的做法,但实际上并不是,我们可以把复杂度的式子写下来。

先记 son_{u,x}sonu,x 为 uu 的第 xx 个子节点,以及 s(u)s(u) 为子树 uu 的大小,则我们有:

T(n)=\sum\limits_{u}\sum\limits_{v=son_{u,i}}((\sum\limits_{j=1}^{i-1}s(son_{u,j})) \times s(v))T(n)=uv=sonu,i((j=1i1s(sonu,j))×s(v)).

我们可以将上式理解成,在枚举到 uu 和 uu 的第 ii 个儿子 vv 时,对于 uu 的所有已经枚举过的儿子,

这些点以及它们子树中的所有点构成了一个点集 VV;

而我们正在进行的这一次枚举,会对时间复杂度造成 |V| \times s(v)V×s(v) 的贡献,

并将点 vv 和 vv 的子树中所有点合并进点集 VV 中。

我们考虑拆开贡献计算,具体来说,是:

因为每次枚举时,任一点集 VV 中的点 pp,和任一点 vv 及其子树中的点 qq,

都会对上面时间复杂度中的 |V| \times s(v)V×s(v) 造成值为 11 的贡献。

又因为 pp 和 qq 在树上的最近公共祖先必然是 uu,也就是说,在这次枚举之前,

我们一定没有枚举到过 u,vu,v,故这是点对 p,qp,q 第一次造成贡献,

但我们在枚举完 u,vu,v 后,会把点 vv 和 vv 的子树中所有点合并进点集 VV 中,

那么这也是是点对 p,qp,q 最后一次造成贡献,因为点集 VV 不会分裂,

且任意同属于点集 VV 中的两点不会造成贡献。

以上证明说明了,对 \forall u,v \in [1,n]u,v[1,n],点对 u,vu,v 只会对时间复杂度造成一次值为 11 的贡献。

也就是说,最后时间复杂度只会被累加 O(n^2)O(n2) 次,即 T(n)=O(n^2)T(n)=O(n2).

code:

#include<bits/stdc++.h>

#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define int long long
#define pii pair < int , int >
#define swap(u, v) u ^= v, v ^= u, u ^= v
#define ckmax(a, b) ((a) = max((a), (b)))
#define ckmin(a, b) ((a) = min((a), (b)))
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define per(i, a, b) for (int i = (a); i >= (b); i--)
#define edg(i, v, u) for (int i = head[u], v = e[i].to; i; i = e[i].nxt, v = e[i].to)

using namespace std;

inline int read() {
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') f = ch == '-' ? -1 : 1, ch = getchar();
    while (ch >= '0' && ch <= '9') x = x * 10 + ch - 48, ch = getchar();
    return x * f;
}

const int N (5e3 + 10);
const int mod (998244353);

void add (int &x, int y) { x = (x + y) % mod; }

int ksm (int a, int b) {
    int r = 1;
    for (; b; a = a * a % mod, b >>= 1) if (b & 1) r = r * a % mod;
    return r;
}

int n;
int cnt;
int g[N];
char s[N];
int fac[N];
int inv[N];
int head[N];
int f[N][N];
int siz[N][2];

struct Edge { int to, nxt; } e[N << 1];

void adde (int u, int v) {
    e[++cnt] = (Edge) {v, head[u]}, head[u] = cnt;
}

int Siz (int x) { return siz[x][0] + siz[x][1]; }

int C (int bi, int sm) {
    return fac[bi] * inv[sm] % mod * inv[bi - sm] % mod;
}

void dfs (int u, int fa) {
    f[u][0] = 1;
    rep (c, 0, 1) siz[u][c] = (s[u] == ('0' + c));
    edg (pth, v, u) if (v ^ fa) {
        dfs (v, u); int Su = Siz(u), Sv = Siz(v);
        rep (i, 0, Su + Sv) g[i] = 0;
        rep (i, 0, min (Su, n / 2)) rep (j, 0, min (Sv, n / 2 - i))
          add (g[i + j], f[v][j] * f[u][i] % mod);
        rep (i, 0, Su + Sv) f[u][i] = g[i];
        rep (c, 0, 1) siz[u][c] += siz[v][c];
    }
    per (i, min (siz[u][0], siz[u][1]), 1) {
        int x = ((s[u] == '1') ? siz[u][0] : siz[u][1]) - (i - 1);
        add (f[u][i], f[u][i - 1] * x % mod);
    }
}

signed main() {
    n = read(), cin >> (s + 1);
    rep (i, 2, n) {
        int u = read(), v = read();
        adde (u, v), adde (v, u);
    }
    fac[0] = inv[0] = 1;
    rep (i, 1, n + 2) fac[i] = fac[i - 1] * i % mod, 
                      inv[i] = ksm (fac[i], mod - 2);
    dfs (1, 0);
    rep (k, 0, n / 2 + 1) (f[1][k] *= fac[n / 2 - k]) %= mod;
    rep (k, 0, n / 2) {
        int res = 0, sgn = mod - 1;
        rep (i, k, n / 2) {
            int tmp = C (i, k) * ((i - k & 1) ? mod - 1 : 1) % mod;
            res = (res + tmp * f[1][i] % mod) % mod;
        }
        cout << res << endl;
    }
    return 0;
}