题解 数树
给定一些点之间的约束,求合法方案数
看起来挺套路的,考虑容斥
令 \(s[i]\) 为至少有 \(i\) 个约束条件不满足时的方案数
注意这里的「方案数」指的是选出 \(i\) 条边,令它们不满足约束的选法数
发现如果有两条边的起点或终点是同一个点会炸锅,所以这个选法数不能组合数算
那就树形DP好了
第一个思路是令 \(dp[i][j][0/1/2]\) 表示以 \(i\) 为根的子树,子树内至少有 \(j\) 条边不满足约束,点 \(i\) 与父节点间的边为 不选/向上/向下 时的选法数
但好像没法转移
所以根据题解换成了 \(dp[i][j][0/1/2/3]\) 表示点 \(i\) 没有连边/连了一条入边/连了一条出边/连了一条入边和一条出边 时的选法数
考虑转移(这里是子树合并):
首先如果 \(u\) 与 \(v\) 间的这条边不选,那 \(v\) 的四种情况对 \(u\) 的四种情况都能产生贡献
然后考虑剩下的情况
那转移大概是这样的
然后就是容斥那块
这里求了至少,想要恰好,所以二项式反演就好
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 5010
#define ll long long
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n;
int head[N], size, sta[N], top, siz[N];
ll dp[N][N][4], tem[N][4], s[N], fac[N];
const ll mod=998244353;
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}
inline ll md(ll a) {return a>=mod?a-mod:a;}
struct edge{int to, next; bool up;}e[N<<1];
inline void add(int s, int t, bool v) {e[++size].to=t; e[size].next=head[s]; e[size].up=v; head[s]=size;}
void dfs(int u, int fa) {
siz[u]=1; dp[u][0][0]=1;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
dfs(v, u);
ll sum;
memset(tem, 0, sizeof(tem));
for (int j=siz[u]; ~j; --j) {
for (int k=siz[v]; ~k; --k) {
sum=0;
for (int h=0; h<4; ++h) md(sum, dp[v][k][h]);
for (int h=0; h<4; ++h) md(tem[j+k][h], dp[u][j][h]*sum%mod);
if (!e[i].up) {
md(tem[j+k+1][2], dp[u][j][0]*md(dp[v][k][0]+dp[v][k][2])%mod);
md(tem[j+k+1][3], dp[u][j][1]*md(dp[v][k][2]+dp[v][k][0])%mod);
}
else {
md(tem[j+k+1][1], dp[u][j][0]*md(dp[v][k][0]+dp[v][k][1])%mod);
md(tem[j+k+1][3], dp[u][j][2]*md(dp[v][k][1]+dp[v][k][0])%mod);
}
}
}
siz[u]+=siz[v];
for (int j=0; j<=siz[u]; ++j)
for (int k=0; k<4; ++k)
dp[u][j][k]=tem[j][k];
}
}
signed main()
{
memset(head, -1, sizeof(head));
n=read();
fac[0]=fac[1]=1;
for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
for (int i=1,u,v; i<n; ++i) {
u=read(); v=read();
add(u, v, 0); add(v, u, 1);
}
dfs(1, 0);
for (int i=0; i<n; ++i)
for (int j=0; j<4; ++j)
md(s[i], dp[1][i][j]);
//cout<<"s: "; for (int i=0; i<n; ++i) cout<<s[i]<<' '; cout<<endl;
for (int i=0; i<n; ++i) s[i]=s[i]*fac[n-i]%mod;
ll ans=0;
for (int i=0; i<n; ++i)
ans=(ans+(i&1?-1ll:1ll)*s[i]%mod)%mod;
printf("%lld\n", (ans%mod+mod)%mod);
return 0;
}