【BZOJ2655】calc DP 数学 拉格朗日插值
题目大意
一个序列\(a_1,\ldots,a_n\)是合法的,当且仅当:
长度为给定的\(n\)。
\(a_1,\ldots,a_n\)都是\([1,m]\)中的整数。
\(a_1,\ldots,a_n\)互不相等。
一个序列的值定义为它里面所有数的乘积,即\(a_1\times a_2\times\cdots\times a_n\)。
求所有不同合法序列的值的和。
两个序列不同当且仅当他们任意一位不一样。
输出答案对一个数\(p\)取余的结果。
\(n\leq500,m\leq {10}^9,p\leq{10}^9,n+1<m<p\)且\(p\)是质数。
题解
这题做法很多种。
设\(f_{i,j}\)为前\(i\)个数中选\(j\)个数的所有方案的值的和,容易得到递推式:\(f_{0,0}=1,f_{i,j}=f_{i-1,j-1}\times i\times j+f_{i-1,j}\)。最后\(ans=f_{m,n}\)。但是这题\(m\)很大,不能直接求出答案。怎么办呢?
我们先打个表:
\(f\) | \(0\) | \(1\) | \(2\) |
---|---|---|---|
\(0\) | \(1\) | \(0\) | \(0\) |
\(1\) | \(1\) | \(1\) | \(0\) |
\(2\) | \(1\) | \(3\) | \(4\) |
\(3\) | \(1\) | \(6\) | \(22\) |
\(4\) | \(1\) | \(10\) | \(70\) |
\(5\) | \(1\) | \(15\) | \(170\) |
\(6\) | \(1\) | \(21\) | \(350\) |
什么?你看不出来?
\(f\) | \(0\) | \(1\) | \(2\) |
---|---|---|---|
\(0\) | \(1\) | \(0\) | \(0\) |
\(1\) | \(1\) | \(i\) | \(0\) |
\(2\) | \(1\) | \(2i-1\) | \(2i^2-2i\) |
\(3\) | \(1\) | \(3i-3\) | \(6i^2-12i+4\) |
\(4\) | \(1\) | \(4i-6\) | \(12i^2-36i+22\) |
\(5\) | \(1\) | \(5i-10\) | \(20i^2-80i+70\) |
\(6\) | \(1\) | \(6i-15\) | \(30i^2-150i+170\) |
你还是看不出来?那我就直接告诉你吧。\(f_{i,0}=1,f_{i,1}=\frac12i^2-\frac12i,f_{i,2}=\frac14i^4+\frac16i^3-\frac14i^2-\frac16i\)。我们会发现,\(f_{i,j}\)是一个最高次项为\(2j\)的多项式,也就是说,\(f_{m,n}\)是一个最高次项为\(2n\)的多项式。我们只用求出\(0\)到\(2n\)次项的系数就可以求答案了。我们可以把前面\(0\)~\(2n\)个\(f_{i,n}\)求出来,就可以用拉格朗日插值插出多项式了。
这道题因为是求某一个点的值,并不要求求出多项式,而且\(x\)取的是\([0,2n]\),所以可以\(O(n)\)求出答案。然而并没有什么用,因为前面的DP已经是\(O(n^2)\)的了。
时间复杂度:\(O(n^2)\)
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
ll p;
ll f[1010][1010];
ll fp(ll a,ll b)
{
ll s=1;
while(b)
{
if(b&1)
s=s*a%p;
a=a*a%p;
b>>=1;
}
return s;
}
int main()
{
int n,m;
scanf("%d%d%lld",&m,&n,&p);
int i,j;
memset(f,0,sizeof f);
f[0][0]=1;
for(i=1;i<=2*n;i++)
{
f[i][0]=f[i-1][0];
for(j=1;j<=n;j++)
f[i][j]=(f[i-1][j-1]*i%p*j+f[i-1][j])%p;
}
if(m<=2*n)
{
printf("%lld\n",f[m][n]);
return 0;
}
ll ans=0;
for(i=0;i<=2*n;i++)
{
ll s1=1,s2=1;
for(j=0;j<=2*n;j++)
if(j!=i)
{
s1=(s1*(m-j))%p;
s2=(s2*(i-j))%p;
}
ans=(ans+f[i][n]*s1%p*fp(s2,p-2)%p)%p;
}
ans=(ans%p+p)%p;
printf("%lld\n",ans);
return 0;
}