题解 数树

传送门

给定一些点之间的约束,求合法方案数
看起来挺套路的,考虑容斥
\(s[i]\) 为至少有 \(i\) 个约束条件不满足时的方案数
注意这里的「方案数」指的是选出 \(i\) 条边,令它们不满足约束的选法数
发现如果有两条边的起点或终点是同一个点会炸锅,所以这个选法数不能组合数算
那就树形DP好了
第一个思路是令 \(dp[i][j][0/1/2]\) 表示以 \(i\) 为根的子树,子树内至少有 \(j\) 条边不满足约束,点 \(i\) 与父节点间的边为 不选/向上/向下 时的选法数
但好像没法转移
所以根据题解换成了 \(dp[i][j][0/1/2/3]\) 表示点 \(i\) 没有连边/连了一条入边/连了一条出边/连了一条入边和一条出边 时的选法数
考虑转移(这里是子树合并):
首先如果 \(u\)\(v\) 间的这条边不选,那 \(v\) 的四种情况对 \(u\) 的四种情况都能产生贡献
然后考虑剩下的情况
那转移大概是这样的
image
然后就是容斥那块
这里求了至少,想要恰好,所以二项式反演就好

Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 5010
#define ll long long 
//#define int long long 

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n;
int head[N], size, sta[N], top, siz[N];
ll dp[N][N][4], tem[N][4], s[N], fac[N];
const ll mod=998244353;
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}
inline ll md(ll a) {return a>=mod?a-mod:a;}
struct edge{int to, next; bool up;}e[N<<1];
inline void add(int s, int t, bool v) {e[++size].to=t; e[size].next=head[s]; e[size].up=v; head[s]=size;}

void dfs(int u, int fa) {
	siz[u]=1; dp[u][0][0]=1;
	for (int i=head[u],v; ~i; i=e[i].next) {
		v = e[i].to;
		if (v==fa) continue;
		dfs(v, u);
		ll sum;
		memset(tem, 0, sizeof(tem));
		for (int j=siz[u]; ~j; --j) {
			for (int k=siz[v]; ~k; --k) {
				sum=0;
				for (int h=0; h<4; ++h) md(sum, dp[v][k][h]);
				for (int h=0; h<4; ++h) md(tem[j+k][h], dp[u][j][h]*sum%mod);
				if (!e[i].up) {
					md(tem[j+k+1][2], dp[u][j][0]*md(dp[v][k][0]+dp[v][k][2])%mod);
					md(tem[j+k+1][3], dp[u][j][1]*md(dp[v][k][2]+dp[v][k][0])%mod);
				}
				else {
					md(tem[j+k+1][1], dp[u][j][0]*md(dp[v][k][0]+dp[v][k][1])%mod);
					md(tem[j+k+1][3], dp[u][j][2]*md(dp[v][k][1]+dp[v][k][0])%mod);
				}
			}
		}
		siz[u]+=siz[v];
		for (int j=0; j<=siz[u]; ++j) 
			for (int k=0; k<4; ++k)
				dp[u][j][k]=tem[j][k];
	}
}

signed main()
{
	memset(head, -1, sizeof(head));
	n=read();
	fac[0]=fac[1]=1;
	for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
	for (int i=1,u,v; i<n; ++i) {
		u=read(); v=read();
		add(u, v, 0); add(v, u, 1);
	}
	dfs(1, 0);
	for (int i=0; i<n; ++i)
		for (int j=0; j<4; ++j)
			md(s[i], dp[1][i][j]);
	//cout<<"s: "; for (int i=0; i<n; ++i) cout<<s[i]<<' '; cout<<endl;
	for (int i=0; i<n; ++i) s[i]=s[i]*fac[n-i]%mod;
	ll ans=0;
	for (int i=0; i<n; ++i)
		ans=(ans+(i&1?-1ll:1ll)*s[i]%mod)%mod;
	printf("%lld\n", (ans%mod+mod)%mod);
		
	return 0;
}
posted @ 2021-09-07 17:00  Administrator-09  阅读(5)  评论(0编辑  收藏  举报