题解 P4099 [HEOI2013]SAO
之前写的是什么鬼……重写了
题意简述
一棵 \(n\) 个点的树,边有方向,求其拓扑序数量,答案对 \(10^9+7\) 取模。
\(n \leq 10^3\)。
题目分析
将拓扑序理解成给每一个点赋一个值 \(t_u\),求满足 \(\forall (u,v),t_u<t_v\) 的排列 \(\{t_i\}\) 。
如果边的方向都是统一的(即第三个部分分),图退化成一棵有根树,可以使用 \(O(n)\) 的树形 dp 解决。
本题中的树边方向不统一,似乎只能将它当作 DAG
来考虑,但这样不仅丢失了许多性质,也难以设计状态。
不妨直接进行树形 dp,对于两种方向分别进行转移。
考虑设计状态,观察到方案数只与值的大小关系有关,设 \(f_{i,j}\) 表示 \(i\) 在其子树中的排名为 \(j\) 的方案数。
转移使用类似树形背包的方式,不断加入子树。
现在考虑将树 \(v\) 加入树 \(u\) 得到新的树 \(u'\),计算 \(f_{u',k}\) 的值。
-
\(v\) 连向 \(u\):
设加入前 \(u\) 在树 \(u\) 内的排名为 \(p\),\(v\) 在树 \(v\) 内的排名为 \(q\)。
-
考虑限制:加入后 \(u\) 的排名变为了 \(k\),说明树 \(v\) 内有 \(k-p\) 个点比 \(u\) 小,而 \(q\) 在树 \(v\) 内的排名为 \(q\),因此有限制 \(q \leq k-p\),即 \(p+q \leq k\)。
-
考虑计算:\(u\) 的排名由 \(p\) 变为 \(k\),比 \(u\) 小的点原来有 \(p-1\) 个,新加的有 \(k-p\) 个,而这些点实际上是确定的(就是 \(u\) 树里最小的 \(p-1\) 个点和 \(v\) 树的 \(k-p\) 个点),因此只需考虑它们间的顺序:可重排列 \(\dbinom{k-1}{p-1}\);同理,比 \(u\) 大的点原来有 \(siz_u-p\) 个,新加的有 \(siz_v-k+p\) 个,可重排列 \(\dbinom{siz_u+siz_v-k}{siz_u-p}\)。
综上,可以得到转移式:
\[f_{u',k}=\sum\limits_{p=1}^{siz_u} \sum\limits_{q=1}^{siz_v} [p+q \leq k] f_{u,p} f_{v,q} \dbinom{k-1}{p-1} \dbinom{siz_u+siz_v-k}{siz_u-p} \] -
-
\(u\) 连向 \(v\):
-
考虑限制:同理,有限制 \(q > k-p\),即 \(p+q > k\)。
-
考虑计算:与第一种情况相同。
同理,可以得到转移式:
\[f_{u',k}=\sum\limits_{p=1}^{siz_u} \sum\limits_{q=1}^{siz_v} [p+q > k] f_{u,p} f_{v,q} \dbinom{k-1}{p-1} \dbinom{siz_u+siz_v-k}{siz_u-p} \] -
总复杂度 \(O(n^4)\)。
考虑优化,将只与 \(p\) 有关的东西提出来:
就只剩下一个 \(f_{v,q}\) 的前缀和,即可做到 \(O(n^2)\),可以通过本题。
\(O(n^2 \log n)\) 的任意模数 NTT 也可以过去……
代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e3+5,mod=1e9+7;
int mul[N],inv[N],pw[N];
int ksm(int a,int x){
int tot=1;
while(x){
if(x & 1ll) tot=1ll*tot*a%mod;
a=1ll*a*a%mod;
x>>=1ll;
}
return tot;
}
void PRE(){
int lim=N-5;
mul[0]=inv[0]=1;
for(int i=1;i<=lim;i++) mul[i]=1ll*mul[i-1]*i%mod;
inv[lim]=ksm(mul[lim],mod-2);
for(int i=lim-1;i>=1;i--) inv[i]=1ll*inv[i+1]*(i+1)%mod;
}
int C(int m,int n){
if(m<n || m<0 || n<0) return 0;
return 1ll*mul[m]*inv[n]%mod*inv[m-n]%mod;
}
struct nod{
int to,nxt,w;
}e[N*2];
int head[N],cnt;
void add(int u,int v,int w){
e[++cnt].to=v;
e[cnt].w=w;
e[cnt].nxt=head[u];
head[u]=cnt;
}
int T,n,ans;
int siz[N],f[N][N],pre[N][N],suf[N][N],tmp[N];
void dfs(int u,int fa){
siz[u]=1;
f[u][1]=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa) continue;
dfs(v,u);
if(e[i].w==1){
for(int p=siz[u];p>=1;p--)
for(int q=siz[v];q>=1;q--)
tmp[p+q]=(tmp[p+q]+1ll*f[u][p]*C(p+q-1,p-1)%mod*C(siz[u]+siz[v]-p-q,siz[u]-p)%mod*pre[v][q]%mod)%mod;
}
else{
for(int p=siz[u];p>=1;p--)
for(int q=siz[v];q>=0;q--)
tmp[p+q]=(tmp[p+q]+1ll*f[u][p]*C(p+q-1,p-1)%mod*C(siz[u]+siz[v]-p-q,siz[u]-p)%mod*suf[v][q+1]%mod)%mod;
}
siz[u]+=siz[v];
for(int t=1;t<=siz[u];t++) f[u][t]=tmp[t],tmp[t]=0;
}
pre[u][0]=f[u][0];
for(int t=1;t<=siz[u];t++) pre[u][t]=(pre[u][t-1]+f[u][t])%mod;
suf[u][siz[u]+1]=0;
for(int t=siz[u];t>=1;t--) suf[u][t]=(suf[u][t+1]+f[u][t])%mod;
}
int main(){
PRE();
cin>>T;
while(T-->0){
cin>>n;
cnt=0;
memset(head,0,sizeof(head));
memset(f,0,sizeof(f));
ans=0;
for(int i=1;i<n;i++){
int u,v;
char w;
scanf("%d %c%d",&u,&w,&v);
u++,v++;
if(w=='>') add(u,v,1),add(v,u,0);
else add(u,v,0),add(v,u,1);
}
dfs(1,0);
for(int i=1;i<=n;i++) ans=(ans+f[1][i])%mod;
cout<<ans<<endl;
}
}