【luogu CF1707D】Partial Virtual Trees(容斥)(DP)

Partial Virtual Trees

题目链接:luogu CF1707D

题目大意

给你一棵以 1 为根的数,问你对于每个长度,有多少个点集序列,第一个点集是全部点,最后一个点集只有 1 号点,且中间每个点集都是上一个点集的真子集,而且每个点集内两两点的 LCA 都在点集中。

思路

发现其实这个每次是真子集不太好处理,我们考虑如果不是真子集,就是子集。
那如果求出来了,我们不难通过容斥得到真子集的答案。
\(f_i=g_i-\sum\limits_{j=1}^{i-1}f_{i-1}\binom{i}{j}\)
(就是选那几次操作删的点集是空的)

那考虑子集怎么弄,考虑删的过程,要怎么删才能满足是虚树。
那考虑树上的 DP,删一个子树的根会在什么时候,有两种情况:
一是最后删,二是先删剩一个子树,然后把它删了,再删剩下的子树。
那你可以设 \(f_{i,j}\)\(j\) 步删完 \(i\) 的子树的方案数。
然后不难首先最后删其实就是 \(\prod\limits_{son}(\sum\limits_{k=1}^{j}f_{son,k})\)
那这个用前缀和维护一个 \(g_{i,j}=\sum\limits_{k=1}^jf_{i,k}\) 即可求。

那第二个的话你枚举剩下的子树,那其他子树的方案数就是也是上面那些乘起来(不能乘你剩下的那个),不过要注意的是这些只能是不超过 \(j-1\) 次就删完。
那你就把前面那个东西也前缀和一下(先枚举剩下的子树,再枚举删的次数),跟 \(f_{sp_{son},j}\) 乘起来就可以共吸纳给 \(f_{i,j}\) 啦。
然后我们给容斥用的就是 \(f_{1,i}\)

不过要注意的一点是第二个删法在当前根是 \(1\) 的时候不能用,因为你最后要留下的是 \(1\),换而言之 \(1\) 是要最晚删去的。

代码

#include<cstdio>

using namespace std;

const int N = 2e3 + 100;
struct node {
	int to, nxt;
}e[N << 1];
int n, mo, le[N], KK, sz[N], tmp[N];
int jc[N], inv[N], invs[N], f[N][N], g[N][N];
int sum[N], s0[N], ans[N];

int add(int x, int y) {return x + y >= mo ? x + y - mo : x + y;}
int dec(int x, int y) {return x < y ? x - y + mo : x - y;}
int mul(int x, int y) {return 1ll * x * y % mo;}
int C(int n, int m) {
	if (n < 0 || m < 0 || n < m) return 0;
	return mul(jc[n], mul(invs[m], invs[n - m]));
}
int ksm(int x, int y) {
	int re = 1;
	while (y) {
		if (y & 1) re = mul(re, x);
		x = mul(x, x); y >>= 1;
	}
	return re;
}
int Inv(int x) {return ksm(x, mo - 2);}

void Add(int x, int y) {
	e[++KK] = (node){y, le[x]}; le[x] = KK;
}

void work(int now, int father) {
	for (int i = le[now]; i; i = e[i].nxt)
		if (e[i].to != father)
			work(e[i].to, now);
	for (int i = 1; i <= n; i++) {
		sum[i] = 1; s0[i] = 0;
		for (int j = le[now]; j; j = e[j].nxt)
			if (e[j].to != father) {
				if (!g[e[j].to][i]) s0[i]++;
					else sum[i] = mul(sum[i], g[e[j].to][i]);
			}
	}
	for (int i = 1; i <= n; i++)
		if (s0[i]) f[now][i] = 0;
			else f[now][i] = sum[i];
	if (now != 1) {
		for (int i = le[now]; i; i = e[i].nxt)
			if (e[i].to != father) {
				for (int j = 1; j <= n; j++) {
					if (!g[e[i].to][j]) s0[j]--;
						else sum[j] = mul(sum[j], Inv(g[e[i].to][j]));
					if (!s0[j]) tmp[j] = add(tmp[j - 1], sum[j]);
						else tmp[j] = tmp[j - 1];
					f[now][j] = add(f[now][j], mul(f[e[i].to][j], tmp[j - 1]));
					if (!g[e[i].to][j]) s0[j]++;
						else sum[j] = mul(sum[j], g[e[i].to][j]);
				}
			}
	}
	for (int i = 1; i <= n; i++) g[now][i] = add(g[now][i - 1], f[now][i]);
}

int main() {
	scanf("%d %d", &n, &mo);
	for (int i = 1; i < n; i++) {
		int x, y; scanf("%d %d", &x, &y);
		Add(x, y); Add(y, x);
	}
	
	jc[0] = 1; for (int i = 1; i < N; i++) jc[i] = mul(jc[i - 1], i);
	inv[0] = inv[1] = 1; for (int i = 2; i < N; i++) inv[i] = mul(inv[mo % i], mo - mo / i);
	invs[0] = 1; for (int i = 1; i < N; i++) invs[i] = mul(invs[i - 1], inv[i]);
	
	work(1, 0);
	for (int i = 1; i < n; i++) {
		ans[i] = f[1][i];
		for (int j = 1; j < i; j++)
			ans[i] = dec(ans[i], mul(ans[j], C(i, j)));
		printf("%d ", ans[i]);
	}
	
	return 0;
}
posted @ 2023-01-14 01:54  あおいSakura  阅读(19)  评论(0编辑  收藏  举报