BZOJ4332: JSOI2012 分零食
BZOJ4332: JSOI2012 分零食
https://lydsy.com/JudgeOnline/problem.php?id=4332
分析:
- 好题,我们做一个\(dp\), \(g[n][m]\)表示\(n\)个人分\(m\)个零食的答案。
- \(g[n][m]=\sum\limits_{i=1}^{m}g[n-1][m-i]\times f(i)\)
- \(g[n]=g[n-1]\times f\)
- \(g[n]=g[0]\times f^n=f^n\)。
- 我们设\(F[n]=\sum\limits_{i=1}^{n}g[i]\)
- 答案就是\(F[n][m]\)
- \(F[n]=F[n/2]+F[n/2]\times g[n/2]\)
- 分治即可。
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
typedef double f2;
typedef long double f3;
#define N 50050
int n,m,p;
const f2 pi=acos(-1);
struct cp {
f2 x,y;
cp() {}
cp(f2 x_,f2 y_) {x=x_,y=y_;}
cp operator + (const cp &u) const {
return cp(x+u.x, y+u.y);
}
cp operator - (const cp &u) const {
return cp(x-u.x, y-u.y);
}
cp operator * (const cp &u) const {
return cp(x*u.x-y*u.y, x*u.y+y*u.x);
}
}F[N],g[N],A[N],f[N];
void fft(cp *a,int len,int flg) {
int i,j,k,t;
cp w,wn,tmp;
for(i=k=0;i<len;i++) {
if(i>k) swap(a[i],a[k]);
for(j=len>>1;(k^=j)<j;j>>=1) ;
}
for(k=2;k<=len;k<<=1) {
t=k>>1;
wn=cp(cos(2*pi*flg/k),sin(2*pi*flg/k));
for(i=0;i<len;i+=k) {
w=cp(1,0);
for(j=i;j<i+t;j++) {
tmp=a[j+t]*w;
a[j+t]=a[j]-tmp;
a[j]=a[j]+tmp;
w=w*wn;
}
}
}
if(flg==-1) {
for(i=0;i<len;i++) a[i].x/=len,a[i].x=int(a[i].x+0.1)%p,a[i].y=0;
}
}
int all;
void solve(cp *f,cp *g,int len) {
if(!len) {g[0]=cp(1,0); return ;}
int i;
if(len&1) {
solve(f,g,len-1);
fft(g,all,1);
for(i=0;i<all;i++) g[i]=g[i]*F[i];
fft(g,all,-1);
for(i=0;i<=m;i++) f[i]=f[i]+g[i],f[i].x=int(f[i].x+0.1)%p,f[i].y=0;
for(i=m+1;i<all;i++) g[i]=cp(0,0);
}else {
solve(f,g,len/2);
// memset(A,0,sizeof(A));
for(i=0;i<all;i++) A[i]=f[i];
fft(A,all,1); fft(g,all,1);
for(i=0;i<all;i++) A[i]=A[i]*g[i];
fft(A,all,-1);
for(i=0;i<all;i++) g[i]=g[i]*g[i];
fft(g,all,-1);
for(i=0;i<=m;i++) f[i]=f[i]+A[i],f[i].x=int(f[i].x+0.1)%p,f[i].y=0;
for(i=m+1;i<all;i++) g[i]=cp(0,0);
}
// printf("%.2f\n",f[m].x);
}
int main() {
int o,s,u;
scanf("%d%d%d%d%d%d",&m,&p,&n,&o,&s,&u);
int len=1,i;
for(;len<=(m<<1);len<<=1) ; all=len;
for(i=1;i<=m;i++) {
F[i].x=(o*i*i+s*i+u)%p;
}
fft(F,len,1);
solve(f,g,n);
printf("%d\n",int(f[m].x+0.1)%p);
}