bzoj 4559 [JLoi2016]成绩比较 —— DP+拉格朗日插值
题目:https://www.lydsy.com/JudgeOnline/problem.php?id=4559
看了看拉格朗日插值:http://www.cnblogs.com/ECJTUACM-873284962/p/6833391.html
https://blog.csdn.net/lvzelong2014/article/details/79159346
https://blog.csdn.net/qq_35649707/article/details/78018944
还只会最简单的那种,正好在这道题里可以用到;
计算方案数,可以考虑DP,利用那个所有成绩都小于 B 的性质,枚举超过 B 的一门课;
设计 f[i][j] 表示当前到了第 i 门课,还剩 j 个人被碾压(一开始是所有人都被碾压,然后渐渐突破...);
则 f[i][j] = ∑(j<=t<=n-1) f[i-1][t] * C(n-1-t,rk[i]-1-(t-j)) * C(t,j) * g[i]
其中第一个组合数表示在 n-1-t 个上一次已经不被碾压的人中选出 rk[i]-1-(t-j) 个作为这次成绩高于 B 的人,第二个组合数表示从 t 个上次被碾压的人中选出 j 个这次仍然被碾压(也等同与选出 t-j 个人这次成绩高于 B );
g[i] 则表示在 i 这门课上的成绩分布情况,则选出的人的成绩可以对号入座;
而 g[i] = ∑(1<=j<=lim[i]) j^(n-rk[i]) * (lim[i]-j)^(rk[i]-1),表示若 B 的成绩是 j,则有 n-rk[i] 个人的成绩在 1~j 中选择,有 rk[i]-1 个人的成绩在 lim[i]-j~lim[i] 中选择;
可以发现这是个大约 n+1 次的多项式,所以设出几个点,求出当 x=lim 时的取值即可,这个过程的复杂度是 n^2 的。
代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; typedef long long ll; int const xn=105,mod=1e9+7; int n,m,K,lm[xn],rk[xn],g[xn],c[xn][xn],f[xn][xn],xx[xn],yy[xn]; int rd() { int ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar(); return f?ret:-ret; } int pw(ll a,int b) { ll ret=1; for(;b;b>>=1,a=(a*a)%mod) if(b&1)ret=(ret*a)%mod; return ret; } int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;} void init() { for(int i=0;i<=n;i++)c[i][0]=1; for(int i=1;i<=n;i++) for(int j=1;j<=i;j++)c[i][j]=upt(c[i-1][j]+c[i-1][j-1]); } int solve(int lim,int n,int m) { int num=n+m+2,sum=0; for(int i=1;i<=num;i++) xx[i]=i,yy[i]=upt(yy[i-1]+(ll)pw(i,n)*pw(lim-i,m)%mod); for(int i=1;i<=num;i++) { ll s1=1,s2=1; for(int j=1;j<=num;j++) if(i!=j)//!!! s1=s1*(lim-xx[j])%mod,s2=s2*(xx[i]-xx[j])%mod; sum=upt(sum+s1*pw(s2,mod-2)%mod*yy[i]%mod); } return sum; } int main() { n=rd()-1; m=rd(); K=rd(); init();//n-1 for(int i=1;i<=m;i++)lm[i]=rd(); for(int i=1;i<=m;i++)rk[i]=rd(),g[i]=solve(lm[i],n-rk[i]+1,rk[i]-1);//+1 f[0][n]=1;//n for(int i=1;i<=m;i++) for(int j=K;j<=n;j++)//k for(int t=j;t<=n;t++) { if(t-j>rk[i]-1||j>n-rk[i]+1)continue;//+1! f[i][j]=upt(f[i][j]+(ll)f[i-1][t]*c[t][j]%mod*c[n-t][rk[i]-1-t+j]%mod*g[i]%mod); } printf("%d\n",f[m][K]); return 0; }