P6478 题解

P6478 题解

可能巨佬们都觉得树形背包的时间复杂度分析太简单了,

好像都没写或者只是点了一句话,那我就来补充一下。

题意:

给定一棵点数为 \(n=2m\) 的有根树,每个点有 \(0,1\) 两种边权。

现在要为每一个权为 \(0\) 的点找一个权为 \(1\) 的点与之配对,并对每个 \(k \in [0,\frac {n}{2}]\)

求出恰有 \(k\) 对点的关系是祖先和后代的配对方案数。

做法:

我们记 \(f_k\) 为恰有 \(k\) 对点的关系是祖先和后代的配对方案数,

\(g_k\) 为钦定 \(k\) 对点的关系是祖先和后代的配对方案数,

那么我们有:\(g_k= \sum\limits_{i=k}^{m}\dbinom{i}{k}f_i\),式子的含义是:

先枚举一共有多少对祖先-后代点对,再枚举在这些点对中,我们钦定的是哪 \(k\) 对点。

我们发现,这个式子是一个二项式反演的经典形式,故我们有:

\(f_k=\sum\limits_{i=k}^{m}\dbinom{i}{k}(-1)^{i-k}g_i\),至于二项式反演的证明我就不讲了,相信大家也都会。

也就是说,我们将问题转化为了,对于每个 \(k \in [0,\frac {n}{2}]\)

求出钦定 \(k\) 对点的关系是祖先和后代的配对方案数。

这个问题就很好解决了,我们可以使用树上背包解决。具体来说,就是:

\(f_{u,i}\) 为子树 \(u\) 中选择了 \(i\) 对祖先-后代点对的方案数,则我们有一个显然的转移:

\(f_{u,i}=\sum\limits_{v,j\le i} f_{v,j}f_{u,i-j}\),代表在子树 \(v\) 中选择 \(j\) 对祖先-后代点,其余的在子树 \(u\) 其他部位选。

当然,我们还可以在子树 \(u\) 内选择一个点与 \(u\) 配对,即:

\(f_{u,i}=f_{u,i-1}+siz_{u,1-col_u}\),其中 \(col_u\) 代表点 \(u\) 属于谁,\(siz_{u,c}\) 代表子树 \(u\)\(col\) 值为 \(c\) 的点数。

注意这里需要倒序枚举 \(i\),因为点 \(u\) 只能被配对一次,故应该做 \(01\) 背包。

看上去这个做法是立方级别的做法,但实际上并不是,我们可以把复杂度的式子写下来。

先记 \(son_{u,x}\)\(u\) 的第 \(x\) 个子节点,以及 \(s(u)\) 为子树 \(u\) 的大小,则我们有:

\(T(n)=\sum\limits_{u}\sum\limits_{v=son_{u,i}}((\sum\limits_{j=1}^{i-1}s(son_{u,j})) \times s(v))\).

我们可以将上式理解成,在枚举到 \(u\)\(u\) 的第 \(i\) 个儿子 \(v\) 时,对于 \(u\) 的所有已经枚举过的儿子,

这些点以及它们子树中的所有点构成了一个点集 \(V\)

而我们正在进行的这一次枚举,会对时间复杂度造成 \(|V| \times s(v)\) 的贡献,

并将点 \(v\)\(v\) 的子树中所有点合并进点集 \(V\) 中。

我们考虑拆开贡献计算,具体来说,是:

因为每次枚举时,任一点集 \(V\) 中的点 \(p\),和任一点 \(v\) 及其子树中的点 \(q\)

都会对上面时间复杂度中的 \(|V| \times s(v)\) 造成值为 \(1\) 的贡献。

又因为 \(p\)\(q\) 在树上的最近公共祖先必然是 \(u\),也就是说,在这次枚举之前,

我们一定没有枚举到过 \(u,v\),故这是点对 \(p,q\) 第一次造成贡献,

但我们在枚举完 \(u,v\) 后,会把点 \(v\)\(v\) 的子树中所有点合并进点集 \(V\) 中,

那么这也是是点对 \(p,q\) 最后一次造成贡献,因为点集 \(V\) 不会分裂,

且任意同属于点集 \(V\) 中的两点不会造成贡献。

以上证明说明了,对 \(\forall u,v \in [1,n]\),点对 \(u,v\) 只会对时间复杂度造成一次值为 \(1\) 的贡献。

也就是说,最后时间复杂度只会被累加 \(O(n^2)\) 次,即 \(T(n)=O(n^2)\).

code:

#include<bits/stdc++.h>

#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define int long long
#define pii pair < int , int >
#define swap(u, v) u ^= v, v ^= u, u ^= v
#define ckmax(a, b) ((a) = max((a), (b)))
#define ckmin(a, b) ((a) = min((a), (b)))
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define per(i, a, b) for (int i = (a); i >= (b); i--)
#define edg(i, v, u) for (int i = head[u], v = e[i].to; i; i = e[i].nxt, v = e[i].to)

using namespace std;

inline int read() {
	int x = 0, f = 1; char ch = getchar();
	while (ch < '0' || ch > '9') f = ch == '-' ? -1 : 1, ch = getchar();
	while (ch >= '0' && ch <= '9') x = x * 10 + ch - 48, ch = getchar();
	return x * f;
}

const int N (5e3 + 10);
const int mod (998244353);

void add (int &x, int y) { x = (x + y) % mod; }

int ksm (int a, int b) {
	int r = 1;
	for (; b; a = a * a % mod, b >>= 1) if (b & 1) r = r * a % mod;
	return r;
}

int n;
int cnt;
int g[N];
char s[N];
int fac[N];
int inv[N];
int head[N];
int f[N][N];
int siz[N][2];

struct Edge { int to, nxt; } e[N << 1];

void adde (int u, int v) {
	e[++cnt] = (Edge) {v, head[u]}, head[u] = cnt;
}

int Siz (int x) { return siz[x][0] + siz[x][1]; }

int C (int bi, int sm) {
	return fac[bi] * inv[sm] % mod * inv[bi - sm] % mod;
}

void dfs (int u, int fa) {
	f[u][0] = 1;
	rep (c, 0, 1) siz[u][c] = (s[u] == ('0' + c));
	edg (pth, v, u) if (v ^ fa) {
		dfs (v, u); int Su = Siz(u), Sv = Siz(v);
		rep (i, 0, Su + Sv) g[i] = 0;
		rep (i, 0, min (Su, n / 2)) rep (j, 0, min (Sv, n / 2 - i))
		  add (g[i + j], f[v][j] * f[u][i] % mod);
		rep (i, 0, Su + Sv) f[u][i] = g[i];
		rep (c, 0, 1) siz[u][c] += siz[v][c];
	}
	per (i, min (siz[u][0], siz[u][1]), 1) {
		int x = ((s[u] == '1') ? siz[u][0] : siz[u][1]) - (i - 1);
		add (f[u][i], f[u][i - 1] * x % mod);
	}
}

signed main() {
	n = read(), cin >> (s + 1);
	rep (i, 2, n) {
		int u = read(), v = read();
		adde (u, v), adde (v, u);
	}
	fac[0] = inv[0] = 1;
	rep (i, 1, n + 2) fac[i] = fac[i - 1] * i % mod, 
					  inv[i] = ksm (fac[i], mod - 2);
	dfs (1, 0);
	rep (k, 0, n / 2 + 1) (f[1][k] *= fac[n / 2 - k]) %= mod;
	rep (k, 0, n / 2) {
		int res = 0, sgn = mod - 1;
		rep (i, k, n / 2) {
			int tmp = C (i, k) * ((i - k & 1) ? mod - 1 : 1) % mod;
			res = (res + tmp * f[1][i] % mod) % mod;
		}
		cout << res << endl;
	}
	return 0;
}
posted @ 2021-12-05 16:57  GaryH  阅读(44)  评论(0编辑  收藏  举报