BZOJ2655 Calc - dp 拉格朗日插值法
题意:
给定 m, n, mod（mod 为素数），求所有长度为 n、元素两两不同且都取自 $[1,m]$ 的有序序列的元素乘积之和，结果对 mod 取模。
思路:
首先可以列出 dp 方程：设 $dp[i][j]$ 表示从 $[1, i]$ 中选出 j 个两两不同的数组成有序序列时，所有序列乘积的和（$dp[i][0]=1$）。考虑数 i：要么不选它，要么选它并把它放在 j 个位置中的某一个，
那么 $$dp[i][j] = dp[i-1][j] + dp[i-1][j-1]\cdot i\cdot j$$。
但是这道题的 m 高达 $10^9$，无法直接 dp 到 m。不过可以发现，$dp[x][n]$ 是关于 x 的 2n 次多项式（它等于 $n!$ 乘以 $1,2,\dots,x$ 的 n 次初等对称多项式），
而 2n 次多项式由 2n+1 个点唯一确定。所以只需先求出 $x = 0,1,\dots,2n$ 处的 $dp[x][n]$，再用拉格朗日插值法算出 $dp[m][n]$ 即可。
// BZOJ 2655 "calc".
// Reads m, n, mod and prints, modulo `mod`, the sum over all ordered sequences
// of n pairwise-distinct integers taken from [1, m] of the product of the
// sequence's elements.
//
// Let f_j(x) be that sum for length j with values drawn from [1, x]. Either x
// is unused, or it sits in one of the j positions of the sequence, so
//     f_j(x) = f_j(x - 1) + f_{j-1}(x - 1) * x * j,    f_0(x) = 1.
// f_n(x) is a polynomial of degree 2n in x, so it is determined by its values
// at x = 0, 1, ..., 2n; for large m, f_n(m) is obtained by Lagrange
// interpolation through those 2n + 1 points.
#include <cstdio>
#include <vector>

using ll = long long;

namespace {

// Modulus read from the input. NOTE(review): assumed prime, because modular
// inverses below are computed as a^(mod-2) (Fermat's little theorem).
ll mod = 1;

// Returns base^exp modulo `mod` by square-and-multiply; exp <= 0 yields 1 % mod.
ll pow_mod(ll base, ll exp) {
    base %= mod;
    if (base < 0) base += mod;
    ll result = 1 % mod;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}

// Returns a vector whose element x is f_n(x) mod `mod`, for x = 0, 1, ..., 2n.
// Keeps a single rolling row f[j] = f_j(x) instead of a (2n+1) x (n+1) table.
// Requires n >= 0.
std::vector<ll> tabulate(int n) {
    std::vector<ll> f(n + 1, 0);
    f[0] = 1 % mod;  // exactly one empty sequence, whose product is 1
    std::vector<ll> values;
    values.reserve(2 * n + 1);
    values.push_back(f[n]);  // node x = 0
    for (int x = 1; x <= 2 * n; ++x) {
        // Walk j downwards so f[j - 1] still holds f_{j-1}(x - 1) when read.
        for (int j = n; j >= 1; --j) {
            f[j] = (f[j] + f[j - 1] * x % mod * j) % mod;
        }
        values.push_back(f[n]);
    }
    return values;
}

// Evaluates at `at` (mod `mod`) the unique polynomial of degree < ys.size()
// that takes value ys[i] at node i, for i = 0, 1, ..., ys.size() - 1.
// NOTE(review): the denominators are products of (i - j) with
// |i - j| < ys.size(), so this requires mod >= ys.size() (i.e. mod > 2n);
// otherwise a denominator is 0 mod `mod` and the result is meaningless.
ll lagrange_eval(const std::vector<ll>& ys, ll at) {
    const int count = static_cast<int>(ys.size());
    const ll t = (at % mod + mod) % mod;
    ll total = 0;
    for (int i = 0; i < count; ++i) {
        ll num = 1;  // prod_{j != i} (t - j)
        ll den = 1;  // prod_{j != i} (i - j)
        for (int j = 0; j < count; ++j) {
            if (j == i) continue;
            num = num * ((t - j % mod + mod) % mod) % mod;
            den = den * (((i - j) % mod + mod) % mod) % mod;
        }
        total = (total + ys[i] * num % mod * pow_mod(den, mod - 2)) % mod;
    }
    return total;
}

}  // namespace

int main() {
    ll m = 0;
    int n = 0;
    // Input order is m, n, mod; reject malformed input instead of using garbage.
    if (std::scanf("%lld%d%lld", &m, &n, &mod) != 3 || n < 0 || mod <= 0) {
        return 1;
    }
    const std::vector<ll> values = tabulate(n);
    // Small m is already one of the tabulated nodes; otherwise interpolate.
    const ll answer = (m >= 0 && m <= 2LL * n) ? values[m] : lagrange_eval(values, m);
    std::printf("%lld\n", answer);
    return 0;
}
skr