[BZOJ3167][HEOI2013]SAO[树dp+组合数学]
题意
给定 \(n\) 个节点和 \(n-1\) 个限制,每个节点有一个权值,每个限制形如:\(a_i< a_j\) ,问有多少个 \(1\) 到 \(n\) 排列满足要求。
\(n\leq 1000\) 。
分析
-
猜测复杂度为 \(O(n^2)\) ,并且应该要看成是树形结构。
-
定义状态 \(f_{i,j}\) 表示以 \(i\) 为根的子树内有 \(j\) 个节点权值 \(< a_i\)的合法方案数。
-
考虑转移,记 \(v\) 为 \(u\) 的儿子,有两种情况:
-
\(a_u > a_v\)
\[f_{u,j+k}=\binom{j+k}{k}\binom{{son}_u+{son}_v-1-j-k}{{son}_v-k}*f_{u,j}*\sum_{x=0}^{k-1}f_{v,x}
\]
- \(a_u < a_v\)
\[f_{u,j+k}=\binom{j+k}{k}\binom{{son}_u+{son}_v-1-j-k}{{son}_v-k}*f_{u,j}*\sum_{x=k}^{{son}_v-1}f_{v,x}
\]
- 总时间复杂度为 \(O(n^2)\) 。
代码
#include<bits/stdc++.h>
using namespace std;
#define go(u) for(int i=head[u],v=e[i].to;i;i=e[i].last,v=e[i].to)
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define pb push_back
#define re(x) memset(x,0,sizeof x)
typedef long long LL;
inline int gi(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-48;ch=getchar();}
return x*f;
}
template<typename T>inline bool Max(T &a,T b){return a<b?a=b,1:0;}
template<typename T>inline bool Min(T &a,T b){return b<a?a=b,1:0;}
const int N=1e3 + 7,mod=1e9 + 7;
int T,n,edc;
int head[N],c[N][N],f[N][N],s[N],son[N],g[N];
char str[10];
struct edge{
int last,to,c;
edge(){}edge(int last,int to,int c):last(last),to(to),c(c){}
}e[N*2];
void Add(int a,int b,int c){
e[++edc]=edge(head[a],b,c),head[a]=edc;
e[++edc]=edge(head[b],a,c^1),head[b]=edc;
}
void add(int &a,int b){a+=b;if(a>=mod) a-=mod;}
void init(){
re(head),re(f);edc=0;
}
void dfs(int u,int fa){
son[u]=f[u][0]=1;
go(u)if(v^fa){
dfs(v,u); re(s);re(g);
if(e[i].c==0){
for(int x=0;x<son[v];++x) s[x]=((x?s[x-1]:0)+f[v][x])%mod;
for(int tv=son[u]-1;~tv;--tv)
for(int b=son[v];b;--b)
add(g[tv+b],1ll*c[tv+b][b]*c[son[u]+son[v]-1-tv-b][son[v]-b]%mod*f[u][tv]%mod*s[b-1]%mod);
}else{
for(int x=son[v]-1;~x;--x) s[x]=(s[x+1]+f[v][x])%mod;
for(int tv=son[u]-1;~tv;--tv)
for(int b=son[v]-1;~b;--b)
add(g[tv+b],1ll*c[tv+b][b]*c[son[u]+son[v]-1-tv-b][son[v]-b]%mod*f[u][tv]%mod*s[b]%mod);
}
memcpy(f[u],g,sizeof g);
son[u]+=son[v];
}
}
void work(){
n=gi();
for(int i=1,a,b,c;i<n;++i){
scanf("%d%s%d",&a,str,&b);
c=str[0]=='>'?0:1;
Add(a+1,b+1,c);
}
dfs(1,0);int ans=0;
rep(i,0,n) add(ans,f[1][i]);
printf("%d\n",ans);
}
int main(){
rep(i,0,N-1){
c[i][0]=1;
rep(j,1,i) c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
}
T=gi();
while(T--) init(),work();
return 0;
}