bzoj 4559 [JLoi2016]成绩比较 —— DP+拉格朗日插值

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=4559

看了看拉格朗日插值:http://www.cnblogs.com/ECJTUACM-873284962/p/6833391.html

https://blog.csdn.net/lvzelong2014/article/details/79159346

https://blog.csdn.net/qq_35649707/article/details/78018944

还只会最简单的那种,正好在这道题里可以用到;

计算方案数,可以考虑DP,利用那个所有成绩都小于 B 的性质,枚举超过 B 的一门课;

设计 f[i][j] 表示当前到了第 i 门课,还剩 j 个人被碾压(一开始是所有人都被碾压,然后渐渐突破...);

则 f[i][j] = ∑(j<=t<=n-1) f[i-1][t] * C(n-1-t,rk[i]-1-(t-j)) * C(t,j) * g[i]

其中第一个组合数表示在 n-1-t 个上一次已经不被碾压的人中选出  rk[i]-1-(t-j) 个作为这次成绩高于 B 的人,第二个组合数表示从 t 个上次被碾压的人中选出 j 个这次仍然被碾压(也等同与选出 t-j 个人这次成绩高于 B );

g[i] 则表示在 i 这门课上的成绩分布情况,则选出的人的成绩可以对号入座;

而 g[i] = ∑(1<=j<=lim[i]) j^(n-rk[i]) * (lim[i]-j)^(rk[i]-1),表示若 B 的成绩是 j,则有 n-rk[i] 个人的成绩在 1~j 中选择,有 rk[i]-1 个人的成绩在 lim[i]-j~lim[i] 中选择;

可以发现这是个大约 n+1 次的多项式,所以设出几个点,求出当 x=lim 时的取值即可,这个过程的复杂度是 n^2 的。

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const xn=105,mod=1e9+7;
int n,m,K,lm[xn],rk[xn],g[xn],c[xn][xn],f[xn][xn],xx[xn],yy[xn];
int rd()
{
  int ret=0,f=1; char ch=getchar();
  while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();}
  while(ch>='0'&&ch<='9')ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar();
  return f?ret:-ret;
}
int pw(ll a,int b)
{
  ll ret=1; 
  for(;b;b>>=1,a=(a*a)%mod)
    if(b&1)ret=(ret*a)%mod;
  return ret;
}
int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;}
void init()
{
  for(int i=0;i<=n;i++)c[i][0]=1;
  for(int i=1;i<=n;i++)
    for(int j=1;j<=i;j++)c[i][j]=upt(c[i-1][j]+c[i-1][j-1]);
}
int solve(int lim,int n,int m)
{
  int num=n+m+2,sum=0;
  for(int i=1;i<=num;i++)
    xx[i]=i,yy[i]=upt(yy[i-1]+(ll)pw(i,n)*pw(lim-i,m)%mod);
  for(int i=1;i<=num;i++)
    {
      ll s1=1,s2=1;
      for(int j=1;j<=num;j++)
    if(i!=j)//!!!
      s1=s1*(lim-xx[j])%mod,s2=s2*(xx[i]-xx[j])%mod;
      sum=upt(sum+s1*pw(s2,mod-2)%mod*yy[i]%mod);
    }
  return sum;
}
int main()
{
  n=rd()-1; m=rd(); K=rd(); init();//n-1
  for(int i=1;i<=m;i++)lm[i]=rd();
  for(int i=1;i<=m;i++)rk[i]=rd(),g[i]=solve(lm[i],n-rk[i]+1,rk[i]-1);//+1
  f[0][n]=1;//n
  for(int i=1;i<=m;i++)
    for(int j=K;j<=n;j++)//k
      for(int t=j;t<=n;t++)
    {
      if(t-j>rk[i]-1||j>n-rk[i]+1)continue;//+1!
      f[i][j]=upt(f[i][j]+(ll)f[i-1][t]*c[t][j]%mod*c[n-t][rk[i]-1-t+j]%mod*g[i]%mod);
    }
  printf("%d\n",f[m][K]);
  return 0;
}

 

posted @ 2018-11-23 11:53  Zinn  阅读(160)  评论(0编辑  收藏  举报