P4099 [HEOI2013]SAO(树形dp)
我们设$f[u][k]$表示以拓扑序编号为$k$的点$u$,以$u$为根的子树中的元素所组成的序列方案数
蓝后我们在找一个以$v$为根的子树。
我们的任务就是在合并这两棵树时维护$f[u][k]$
合并时,$v$的元素可能全在点$u$的前/后面,也可能都有。
分类讨论:
1.当有$p(p\in [0,siz[v]])$个元素插入到点$u$(拓扑序)前面时
我们知道插入后点$u$的拓扑序为$k$
那么插入前的拓扑序即为$k-p$
∴插入前子树$u$对应的状态就是$f[u][k-p]$
设$i=k-p$
那么插入的方案数就等价于在$k-1$位置($u$是固定最后的)中选择$i-1$个位置$*f[u][i]$
$=C(k-1,i-1)*f[u][i]$
其中$i\in [1,min(siz[u],k)]$
2.当有$p(p\in [0,siz[v]])$个元素插入到点$u$(拓扑序)后面时
可以这样表示$p=siz[v]-(k-i)=siz[v]-k+i$
而原来$u$(拓扑序)后面的元素共有$siz[u]-i$个
∴方案数$=C(siz[u]-i+p,p)*f[v][j]=C(siz[u]+siz[v]-k,siz[u]-i)*f[v][j]$
$j$的范围需要分类:
- $u<v$(拓扑序)$j\in[k-i+1,siz[v]]$,即点$v$不能填充到拓扑序$<u$的地方
- $u>v$(拓扑序)$j\in[1,k-i]$
整理一下:
$f[u][k]=\sum_{i=1}^{min(siz[u],k)}\sum_{j}(分类讨论)*C(k-1,i-1)*C(siz[u]+siz[v]-k,siz[u]-i)$
然鹅这是$O(n^{3})$
发现是$j$是连续一段区间,可以前缀和处理
蓝后就愉快地变成$O(n^{2})$了
end.
#include<iostream> #include<cstdio> #include<cstring> #include<cctype> #define re register using namespace std; typedef long long ll; void read(int &x){ char c=getchar(); x=0; while(!isdigit(c)) c=getchar(); while(isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar(); }int wt[50]; void output(ll x){ if(!x) {putchar(48); return;} int l=0; while(x) wt[++l]=x%10,x/=10; while(l) putchar(wt[l--]+48); } int min(int &a,int &b) {return a<b?a:b;} const int p=1e9+7; ll mod(ll a){return a<p?a:a-p;} #define N 1001 int t,n,siz[N]; ll f[N][N],C[N][N],sum[N][N],ans; int cnt,hd[N],nxt[N<<1],ed[N],poi[N<<1],dir[N<<1]; bool vis[N]; void adde(int x,int y,int v){ nxt[ed[x]]=++cnt; hd[x]=hd[x]? hd[x]:cnt; ed[x]=cnt; poi[cnt]=y; dir[cnt]=v; } void clears(){//清空数据 for(re int i=1;i<=n;++i){ f[i][1]=1; sum[i][1]=0; vis[i]=0; siz[i]=1; hd[i]=ed[i]=0; nxt[i]=nxt[i+n]=0; for(re int j=2;j<=n;++j) f[i][j]=sum[i][j]=0; }cnt=ans=0; } void dfs(int u){ vis[u]=1; for(int z=hd[u];z;z=nxt[z]){ int v=poi[z]; if(vis[v]) continue; dfs(v); if(dir[z]){//u<v(拓扑序) for(int k=siz[u]+siz[v];k>=1;--k){ ll tmp=0; for(int i=min(siz[u],k);i>=1;--i){ int l=k-i+1,r=siz[v]; //下面只要修改l,r即可 if(l>r) continue; ll r1=mod(sum[v][r]-sum[v][l-1]+p); ll r2=C[k-1][i-1]*C[siz[u]+siz[v]-k][siz[u]-i]%p; tmp=mod(tmp+f[u][i]*r1%p*r2%p); }f[u][k]=tmp; } }else{//u>v(拓扑序) for(int k=siz[u]+siz[v];k>=1;--k){ ll tmp=0; for(int i=min(siz[u],k);i>=1;--i){ int l=1,r=k-i; if(l>r) continue; ll r1=mod(sum[v][r]-sum[v][l-1]+p); ll r2=C[k-1][i-1]*C[siz[u]+siz[v]-k][siz[u]-i]%p; tmp=mod(tmp+f[u][i]*r1%p*r2%p); }f[u][k]=tmp; } }siz[u]+=siz[v]; } for(int i=1;i<=siz[u];++i)//维护前缀和 sum[u][i]=mod(sum[u][i-1]+f[u][i]); } int main(){ read(t); char opt[3]; int q1,q2,q3; for(re int i=0;i<N;++i)//组合数预处理 for(re int j=0;j<=i;++j) C[i][j]=(!j||j==i)?1:mod(C[i-1][j]+C[i-1][j-1]); while(t--){ read(n); clears(); for(re int i=1;i<n;++i){ read(q1); scanf("%s",opt); read(q2); q3=(opt[0]=='<'); ++q1; ++q2; adde(q1,q2,q3); adde(q2,q1,q3^1); }dfs(1); for(re int i=1;i<=n;++i) ans=mod(ans+f[1][i]); output(ans); putchar('\n'); }return 0; }