ARC121E Directed Tree

ARC121E Directed Tree

一个点如果想要合法,首先不能和子树中的点重复,其次不能填子树中的点的编号 —— 这会有后效性,使得我们非常困难地处理 DP 转移。但是如果我们考虑不合法,十分简单 —— 只需要填入子树中的其中一个,并且不和子树中不合法的点填同样的数字即可。

正难则反,看成有多个限制形如 ii 不能在 ii 的子树除自己中出现,考虑容斥,记 F(i)F(i) 表示有 ii 个不合法点,其他不确定的方案数,答案即为 i=0n(1)i×F(i)\sum\limits_{i=0}^{n}(-1)^i\times F(i)

考虑 DP,设 fu,if_{u,i} 表示 uu 的子树内有 ii 个不合法点的情况,转移时先统计儿子的方案数。

fu,a+bfu,a+b+fu,a×fv,bf_{u,a+b}\leftarrow f_{u,a+b}+f'_{u,a}\times f_{v,b}

再考虑转移 uu 点的方案数,uu 能填的不合法位置有 sizu1siz_u-1 个,目前已经填了 i1i-1 个,有 sizuisiz_u-i 个能填的位置,ii 从大到小更新不重复。

fu,ifu,i+(sizui)fu,i1f_{u,i}\leftarrow f_{u,i}+(siz_u-i)f_{u,i-1}

最后答案显然是 i=0n(1)i×f1,i×(ni)!\sum\limits_{i=0}^{n}(-1)^i \times f_{1,i} \times (n-i)!,其中 f1,0=1f_{1,0}=1

时间复杂度 O(n2)\mathcal O(n^2)

#include <bits/stdc++.h>

using namespace std;

#define int long long
#define he putchar('\n')
#define ha putchar(' ')

typedef long long ll;

inline int read()
{
	int x = 0, f = 1;
	char c = getchar();
	while(c < '0' || c > '9')
	{
		if(c == '-') f = -1;
		c = getchar();
	}
	while(c >= '0' && c <= '9')
		x = x * 10 + c - 48, c = getchar();
	return x * f;
}

inline void write(int x)
{
	if(x < 0)
	{
		x = -x;
		putchar('-');
	}
	if(x > 9) write(x / 10);
	putchar(x % 10 + 48);
}

const int _ = 2010, mod = 998244353;

int n, ans, fac[_], f[_][_], siz[_], t[_];

vector<int> d[_];

void dfs(int u, int fa)
{
	f[u][0] = 1;
	for(int v : d[u])
	{
		if(v == fa) continue;
		dfs(v, u);
		for(int j = 0; j <= siz[u]; ++j)
			for(int k = 0; k <= siz[v]; ++k)
				t[j + k] = (t[j + k] + f[u][j] * f[v][k] % mod) % mod;
		siz[u] += siz[v];
		for(int j = 0; j <= siz[u]; ++j) f[u][j] = t[j], t[j] = 0;
	}
	++siz[u];
	for(int i = siz[u] - 1; i >= 0; --i)
		f[u][i + 1] = (f[u][i + 1] + (siz[u] - i - 1) * f[u][i] % mod) % mod;
}

signed main()
{
	n = read();
	for(int i = 2, x; i <= n; ++i)
	{
		x = read();
		d[i].push_back(x), d[x].push_back(i);
	}
	dfs(1, 0);
	fac[0] = 1;
	for(int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;
	for(int i = 0; i <= n; ++i) ans = (ans + (i & 1 ? -1 : 1) * f[1][i] * fac[n - i] % mod + mod) % mod;//, cout << ans << "!!!\n";
	write((ans % mod + mod) % mod), he;
	return 0;
}
posted @ 2022-07-25 15:33  蒟蒻orz  阅读(2)  评论(0编辑  收藏  举报  来源