[集训队互测]calc
题目
点这里看题目。
分析
首先不难想到可以枚举递增的序列,最后在答案里面乘上\(n!\),于是有\(O(nk)\)的暴力 DP 一枚:
\(f(i,j)\)表示长度为\(i\)、最大值\(\le j\)的序列的贡献和。
转移显然:
\[f(i,j)=j\times f(i-1,j-1)+f(i,j-1)
\]
那么可以发现,当序列长度固定的时候,\(f(n,x)\)肯定是关于\(x\)的函数。环顾四周,DP 转移方程中并不存在除法、开方、作为指数乘方等运算,所以可以推测\(f(n,x)\)就是\(x\)的多项式函数。
那么,它的次数是多少呢?这直接决定了我们如何进行插值。设\(f(n,x)\)的次数为\(g(n)\),考虑到转移左右两边的次数应该是相等的,就有:
\[f(i,j)-f(i,j-1)=j\times f(i-1,j-1)\Rightarrow g(n)-1=g(n-1)+1
\]
补充一下,多项式函数做差分,即\(f(x)-f(x-1)\),得到的结果的次数会比原多项式的小一,可以直接用二项式定理展开证明。
然后发现,\(g(n)=g(n-1)+2\),由于\(f(0,x)=1\),所以\(g(0)=0\),得到通项公式\(g(n)=2n\)。
然后我们就知道了\(f(n,x)\)是关于\(x\)的\(2n\)次的多项式函数,因此,我们需要算出\(2n+1\)个点值,用于插值。总时间\(O(n^2)\)。
//f(i,j)=f(i-1,j-1)*j+f(i,j-1)
#include <cstdio>
const int MAXN = 1005;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
int f[MAXN][MAXN], y[MAXN];
int N, K, M, mod;
int qkpow( int base, int indx )
{
int ret = 1;
while( indx )
{
if( indx & 1 ) ret = 1ll * ret * base % mod;
base = 1ll * base * base % mod, indx >>= 1;
}
return ret;
}
int inver( const int a ) { return qkpow( a, mod - 2 ); }
void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
int Lagrange()
{
if( K <= M ) return y[K];
int ans = 0, tmp;
for( int i = 1 ; i <= M ; i ++ )
{
tmp = 1;
for( int j = 1 ; j <= M ; j ++ )
if( i != j )
tmp = 1ll * tmp * ( K - j ) % mod * inver( i - j + mod ) % mod;
add( ans, 1ll * tmp * y[i] % mod );
}
return ans;
}
int main()
{
read( K ), read( N ), read( mod );
M = 2 * N + 1;
for( int j = 0 ; j <= M ; j ++ )
f[0][j] = 1;
for( int i = 1 ; i <= N ; i ++ )
for( int j = 1 ; j <= M ; j ++ )
f[i][j] = ( 1ll * f[i - 1][j - 1] * j % mod + f[i][j - 1] ) % mod;
for( int i = 0 ; i <= M ; i ++ ) y[i] = f[N][i];
int fac = 1;
for( int i = 1 ; i <= N ; i ++ ) fac = 1ll * fac * i % mod;
write( 1ll * fac * Lagrange() % mod ), putchar( '\n' );
return 0;
}