bzoj3167[HEOI2013]SAO
这虽然是游戏,但可不是闹着玩的. ——茅场晶彦
HEOI2013有一道SAO和一道ALO……SwordArtOnline和AlfheimOnline233333
一开始看到这道题,我觉得noip模拟赛考过这道题,然后发现自己想不出来做法,然后开始怀疑自己的DP比联赛之前还弱,然后发现联赛之前考的那道题只有链的情况….
对于只有链的情况,我们不妨认为只有i和i-1号点之间连边,记f[i][j]为前i个点的合法序列中第i个点位于第j个位置的方案数,转移到时候枚举下一个点放在哪里就可以O(n^2).
对于树的情况,我们定义f[i][j]为点i为根的子树的合法序列中,点i位于第j个位置的方案数,然后暴力合并一下,复杂度分析类似于bzoj4033和bzoj4753,就可以做到O(n^2)的时间复杂度.
因为每次合并一棵子树时付出的代价是”已经合并的兄弟子树的大小之和”*”正在合并的这棵子树的大小”,实质上是树上每对节点在LCA处贡献1个时间复杂度
然后我写残了,合并一棵子树付出的代价=”已经合并的兄弟子树的大小之和+正在合并的这棵子树的大小”*”正在合并的这棵子树的大小”,这样复杂度变成O(n^2+所有子树大小的平方之和),例如一条链,从链的一头开始做树形DP,复杂度是n^2+(n-1)^2+(n-2)^2……的,会退化成O(n^3)我试着把链的情况特判掉做O(n^2)的DP,然后还是过不去,看来数据是比较强的.
具体DP的时候,我们依次合并一个节点x的所有儿子的子树,合并完了前k个儿子的时候f[x][i]表示的含义是包括前k个儿子且x位于第i个位置的合法方案数.那么合并第k+1棵子树的时候我们枚举x在前k个儿子的时候的位置在哪里,再枚举第k+1棵子树有几个节点插入到x的前面,用前缀和后缀和优化一下,插入的方案需要预处理组合数,这样复杂度就是O(n^2)的
貌似有些题解的上界不紧,说这个标算是O(n^3)的...讲道理这个O(n^2)复杂度的树形DP的题还是不少的.
#include<cstdio> #include<cstring> #include<cctype> void read(int &x,int &w,int &y){ char ch;while(ch=getchar(),!isdigit(ch));x=ch-'0'; while(ch=getchar(),isdigit(ch))x=x*10+ch-'0'; while(!isgraph(ch))ch=getchar(); if(ch=='<')w=-1; else w=1; while(ch=getchar(),!isdigit(ch));y=ch-'0'; while(ch=getchar(),isdigit(ch))y=y*10+ch-'0'; } const int mod=1000000007; const int maxn=1005; int C[maxn][maxn]; struct edge{ int to,next,w; }lst[maxn<<1];int len=1,first[maxn]; void addedge(int a,int b,int w){ lst[len].to=b;lst[len].next=first[a];lst[len].w=w;first[a]=len++; } int g[maxn]; int f[maxn][maxn]; int pre[maxn][maxn],suf[maxn][maxn]; int sz[maxn],prt[maxn]; void dp(int x){ for(int pt=first[x];pt;pt=lst[pt].next){ if(lst[pt].to!=prt[x]){ prt[lst[pt].to]=x;dp(lst[pt].to); } } sz[x]=1;f[x][1]=1;int multF,multC; for(int pt=first[x];pt;pt=lst[pt].next){ if(lst[pt].to!=prt[x]){ int lim=sz[x]+sz[lst[pt].to]; for(int i=1;i<=lim;++i)g[i]=0; if(lst[pt].w==1){ for(int i=1;i<=sz[x];++i){ for(int j=1;j<=sz[lst[pt].to];++j){ multF=f[x][i]*1LL*suf[lst[pt].to][sz[lst[pt].to]-j+1]%mod;multC=C[sz[x]-i+j][j]*1LL*C[i+sz[lst[pt].to]-j-1][i-1]%mod; g[i+sz[lst[pt].to]-j]=(g[i+sz[lst[pt].to]-j]+multF*1LL*multC)%mod; } } }else{ for(int i=1;i<=sz[x];++i){ for(int j=1;j<=sz[lst[pt].to];++j){ multF=f[x][i]*1LL*pre[lst[pt].to][j]%mod;multC=C[i-1+j][j]*1LL*C[sz[x]-i+sz[lst[pt].to]-j][sz[lst[pt].to]-j]%mod; g[i+j]=(g[i+j]+multF*1LL*multC)%mod; } } } sz[x]+=sz[lst[pt].to]; for(int i=1;i<=sz[x];++i)f[x][i]=g[i]; } } pre[x][1]=f[x][1];suf[x][sz[x]]=f[x][sz[x]]; for(int i=2;i<=sz[x];++i)pre[x][i]=(f[x][i]+pre[x][i-1])%mod; for(int i=sz[x]-1;i>=1;--i)suf[x][i]=(f[x][i]+suf[x][i+1])%mod; } int main(){ int tests;scanf("%d",&tests); for(int i=0;i<maxn;++i)C[i][0]=1; for(int i=1;i<maxn;++i){ for(int j=1;j<=i;++j)C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod; } int x,y,z,n; while(tests--){ scanf("%d",&n);memset(first,0,sizeof(first));len=1; for(int i=0;i<=n;++i){ for(int j=0;j<=n;++j)f[i][j]=0; } for(int i=1;i<n;++i){ read(x,y,z);addedge(x,z,y);addedge(z,x,-y); } prt[1]=-1;dp(1); int ans=0; for(int i=1;i<=n;++i)ans=(ans+f[1][i])%mod; printf("%d\n",ans); } return 0; }
#include<cstdio> #include<cstring> #include<cctype> void read(int &x,int &w,int &y){ char ch;while(ch=getchar(),!isdigit(ch));x=ch-'0'; while(ch=getchar(),isdigit(ch))x=x*10+ch-'0'; while(!isgraph(ch))ch=getchar(); if(ch=='<')w=-1; else w=1; while(ch=getchar(),!isdigit(ch));y=ch-'0'; while(ch=getchar(),isdigit(ch))y=y*10+ch-'0'; } const int mod=1000000007; const int maxn=1005; int C[maxn][maxn]; struct edge{ int to,next,w; }lst[maxn<<1];int len=1,first[maxn]; void addedge(int a,int b,int w){ lst[len].to=b;lst[len].next=first[a];lst[len].w=w;first[a]=len++; } int sum[maxn]; int g[2][maxn]; int f[maxn][maxn]; int sz[maxn],prt[maxn]; void dp(int x){ for(int pt=first[x];pt;pt=lst[pt].next){ if(lst[pt].to!=prt[x]){//printf("!"); prt[lst[pt].to]=x;dp(lst[pt].to);//printf("?"); } //if(x==0)printf("lst[pt].to==%d prt[0]==%d\n",lst[pt].to,prt[0]); } memset(g[0],0,sizeof(g[0])); sz[x]=1;g[0][1]=1;int flag=0; for(int pt=first[x];pt;pt=lst[pt].next){ if(lst[pt].to!=prt[x]){//if(x==2)printf("!"); sz[x]+=sz[lst[pt].to];int anti=flag^1; for(int i=0;i<=sz[x];++i)g[anti][i]=0; if(lst[pt].w==1){ int tmp=0; for(int i=1;i<=sz[x];++i){ tmp=0; for(int j=1;j<=sz[lst[pt].to];++j){ if(j<=i){ tmp=(C[i-1][j-1]*1LL*C[sz[x]-i][sz[lst[pt].to]-j+1]%mod*g[flag][i-(j-1)]+tmp)%mod; } g[anti][i]=(g[anti][i]+tmp*1LL*f[lst[pt].to][j])%mod; } } }else{ int tmp=0; for(int i=1;i<=sz[x];++i){ tmp=0; for(int j=sz[lst[pt].to];j>=1;--j){ if((sz[lst[pt].to]-j)<=sz[x]-i){ tmp=(C[sz[x]-i][sz[lst[pt].to]-j]*1LL*C[i-1][j]%mod*g[flag][i-j]+tmp)%mod; } g[anti][i]=(g[anti][i]+tmp*1LL*f[lst[pt].to][j])%mod; } } } flag^=1; } } //printf("x==%d,flag==%d,%d %d %d %d %d\n",x,flag,g[flag][1],g[flag][2],g[flag][3],g[flag][4],g[flag][5]); for(int i=1;i<=sz[x];++i)f[x][i]=g[flag][i]; } int deg[maxn]; bool is_line(int n){ int cnt1=0,cnt2=0; for(int i=0;i<n;++i){ if(deg[i]>2)return false; if(deg[i]==1)cnt1++; else cnt2++; } return cnt1==2&&cnt2==n-2; } int pre[maxn][maxn],suf[maxn][maxn]; void dfs(int x,int p){ for(int pt=first[x];pt;pt=lst[pt].next){ if(lst[pt].to==p)continue; dfs(lst[pt].to,x); sz[x]=sz[lst[pt].to]+1; if(lst[pt].w==1){ for(int i=1;i<sz[x];++i){ f[x][i]=suf[lst[pt].to][i]; } f[x][sz[x]]=0; }else{ for(int i=2;i<=sz[x];++i){ f[x][i]=pre[lst[pt].to][i-1]; } f[x][1]=0; } } pre[x][1]=f[x][1];suf[x][sz[x]]=f[x][sz[x]]; for(int i=2;i<=sz[x];++i)pre[x][i]=(pre[x][i-1]+f[x][i])%mod; for(int i=sz[x]-1;i>=1;--i)suf[x][i]=(suf[x][i+1]+f[x][i])%mod; } void calc_line(int n){ int start; for(int i=0;i<n;++i){ if(deg[i]==1)start=i; } dfs(start,-1); int ans=0; for(int i=1;i<=n;++i){ ans=(ans+f[start][i])%mod; } printf("%d\n",ans); } int main(){ int tests;scanf("%d",&tests); for(int i=0;i<maxn;++i)C[i][0]=1; for(int i=1;i<maxn;++i){ for(int j=1;j<=i;++j)C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod; } int x,y,z,n; while(tests--){ scanf("%d",&n);memset(first,0,sizeof(first));len=1; memset(f,0,sizeof(f));memset(deg,0,sizeof(deg)); for(int i=1;i<n;++i){ read(x,y,z);addedge(x,z,y);addedge(z,x,-y);deg[x]++;deg[z]++; } if(n<=2)printf("1\n"); else if(is_line(n)){ calc_line(n); } else{ prt[1]=-1;dp(1); // for(int i=0;i<n;++i)printf("%d ",sum[i]); int ans=0; for(int i=1;i<=n;++i)ans=(ans+f[1][i])%mod; printf("%d\n",ans); } } return 0; }