bzoj 2616 SPOJ PERIODNI——笛卡尔树+树形DP
题目:https://www.lydsy.com/JudgeOnline/problem.php?id=2616
把相同高度的连续一段合成一个位置(可能不需要?),用前缀和维护宽度。
然后每次找区间里最低的那个点(ST表)作为根,递归左右孩子,构建笛卡尔树。
dp[ cr ][ j ] 表示在 cr 的子树里选择 j 个点的方案数。
自己本来写的是同时枚举 cr 这个点、ls 、rs 各贡献了多少个车,结果TLE。
看看题解,发现这样比较好(至多 \( n^3 \) ),就是先 \( dp[ cr ][ j ] = \sum dp[ ls ][ k ] * dp[ rs ][ j-k ] ),然后再枚举 cr 的贡献,形如 \( dp[ cr ][ j ] = \sum dp[ cr ][ j-k ] * C_{h}^{k} * C_{w-(j-k)}^{k} * k! \) ,其中 w 表示 cr 这个点的宽,h 表示 cr 这个点的高。
注意那里还要乘一个 \( k! \) 。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; int Mn(int a,int b){return a<b?a:b;} int Mx(int a,int b){return a>b?a:b;} const int N=505,K=10,M=1e6+5,mod=1e9+7; int upt(int x){if(x>=mod)x-=mod;if(x<0)x+=mod;return x;} int pw(int x,int k) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;} int n,ht[N],jc[M],jcn[M],dp[N][N]; int bin[K],lg[N],st[N][K],s[N]; struct Dt{ int bh,mx; Dt(int b=0,int m=0):bh(b),mx(m) {} }; void init(int mx) { jc[0]=1;for(int i=1;i<=mx;i++)jc[i]=(ll)jc[i-1]*i%mod; jcn[mx]=pw(jc[mx],mod-2); for(int i=mx-1;i>=0;i--)jcn[i]=(ll)jcn[i+1]*(i+1)%mod; for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1; bin[0]=1;for(int i=1;i<=lg[n];i++)bin[i]=bin[i-1]<<1; for(int i=1;i<=n;i++)st[i][0]=i; for(int t=1;t<=lg[n];t++) for(int i=1;i+bin[t]-1<=n;i++) { int u=st[i][t-1], v=st[i+bin[t-1]][t-1]; if(ht[u]<ht[v])st[i][t]=u; else st[i][t]=v; } } int C(int n,int m) { if(n<m)return 0;// return (ll)jc[n]*jcn[m]%mod*jcn[n-m]%mod; } int get(int l,int r) { int d=lg[r-l+1]; int u=st[l][d], v=st[r-bin[d]+1][d]; if(ht[u]<ht[v])return u; else return v; } Dt solve(int l,int r,int pr) { if(l>r)return Dt(0,0); int cr=get(l,r), w=s[r]-s[l-1], h=ht[cr]-pr; Dt Ls=solve(l,cr-1,ht[cr]); int ls=Ls.bh,m1=Ls.mx; Dt Rs=solve(cr+1,r,ht[cr]); int rs=Rs.bh,m2=Rs.mx; for(int i=1,l1=m1+m2;i<=l1;i++) for(int j=Mx(0,i-m2),l2=Mn(i,m1);j<=l2;j++) { dp[cr][i]=(dp[cr][i]+(ll)dp[ls][j]*dp[rs][i-j])%mod; } int lm=Mn(h,w), mx=m1+m2+Mn(h,w-m1-m2); dp[cr][0]=1; for(int i=mx;i;i--) for(int j=1,l1=Mn(i,lm);j<=l1;j++) { dp[cr][i]=(dp[cr][i]+ (ll)dp[cr][i-j]*C(h,j)%mod*C(w-i+j,j)%mod*jc[j])%mod; } return Dt(cr,mx); /* printf("(%d,%d)cr=%d w=%d h=%d\n",l,r,cr,w,h); printf(" ls=%d m1=%d rs=%d m2=%d\n",ls,m1,rs,m2); int mx=m1+m2+Mn(h,w-m1-m2); printf(" mx=%d\n",mx); for(int i=1;i<=mx;i++) { printf(" i=%d\n",i); for(int j1=0,l1=Mn(m1,i);j1<=l1;j1++) for(int j2=Mx(0,i-j1-Mn(h,w-j1)),l2=Mn(i-j1,m2);j2<=l2;j2++) { int ret=(ll)dp[ls][j1]*dp[rs][j2]%mod; int k=i-j1-j2; ret=(ll)ret*C(h,k)%mod*C(w-j1-j2,k)%mod*jc[k]%mod;//jc[k] dp[cr][i]=upt(dp[cr][i]+ret); printf(" j1=%d j2=%d k=%d (dp[%d]=%d)\n" ,j1,j2,k,i,dp[cr][i]); } printf(" dp[%d]=%d\n",i,dp[cr][i]); } dp[cr][0]=1; return Dt(cr,mx); */ } int main() { int tn,tm,mx=0; scanf("%d%d",&tn,&tm); for(int i=1,d,lst=0;i<=tn;i++,lst=d) { scanf("%d",&d); mx=Mx(mx,d); if(d!=lst) ht[++n]=d, s[n]=s[n-1]+1; else s[n]++; } init(mx); dp[0][0]=1;/// Dt Rt=solve(1,n,0); printf("%d\n",dp[Rt.bh][tm]); return 0; }