多项式做题笔记
(早上好,笔记在注释里。)
多项式卷积模板:
FFT:
#include<iostream> #include<cstdio> #include<cmath> using namespace std; const int N=3e6+10; double pi=acos(-1); int n,m; struct node{ double x,y; node(double a=0,double b=0){ x=a,y=b; } node operator + (node const &u) const{ return node(x+u.x,y+u.y); } node operator - (node const &u) const{ return node(x-u.x,y-u.y); } node operator * (node const &u) const{ return node(x*u.x-y*u.y,y*u.x+x*u.y); } }f[N]; int pos[N]; void fft(node *f,bool flag){ for(int i=0;i<n;i++){ if(i<pos[i])swap(f[i],f[pos[i]]); } for(int p=2;p<=n;p<<=1){ int len=p>>1; node fir(cos(2*pi/p),sin(2*pi/p)); if(!flag)fir.y*=-1; for(int k=0;k<n;k+=p){ node buf(1,0); for(int l=k;l<k+len;l++){ node tt=buf*f[len+l]; f[len+l]=f[l]-tt; f[l]=f[l]+tt; buf=buf*fir; } } } } int main() { scanf("%d%d",&n,&m); for(int i=0;i<=n;i++)scanf("%lf",&f[i].x); for(int i=0;i<=m;i++)scanf("%lf",&f[i].y); for(m+=n,n=1;n<=m;n<<=1); for(int i=0;i<n;i++){ pos[i]=(pos[i>>1]>>1)|((i&1)?n>>1:0); } fft(f,1); for(int i=0;i<n;i++)f[i]=f[i]*f[i]; fft(f,0); for(int i=0;i<=m;i++)printf("%d ",(int)(f[i].y/n/2+0.5)); return 0; } //FFT比较丢精度,如果需要卷积的多项式系数值域相差太大,就会卡精度 //三次变两次优化涉及的精度跨度上限更大,严重掉精度
#include<iostream> #include<cstdio> using namespace std; const int N=3e6+10,mod=998244353,G=3; int n,m,pos[N]; long long f[N],g[N],invn,invG; long long pw(long long x,long long k){ long long num=1; while(k){ if(k&1)num=num*x%mod; x=x*x%mod; k>>=1; } return num; } void ntt(long long *f,bool flag){ for(int i=0;i<n;i++){ if(i<pos[i])swap(f[i],f[pos[i]]); } for(int p=2;p<=n;p<<=1){ int len=p>>1; long long fir=pw((flag?G:invG),(mod-1)/p); for(int i=0;i<n;i+=p){ long long bur=1; for(int l=i;l<i+len;l++){ long long tt=bur*f[l+len]%mod; f[l+len]=((f[l]-tt)%mod+mod)%mod; f[l]=(f[l]+tt)%mod; bur=bur*fir%mod; } } } } int main() { scanf("%d%d",&n,&m); for(int i=0;i<=n;i++)scanf("%lld",&f[i]); for(int i=0;i<=m;i++)scanf("%lld",&g[i]); for(m+=n,n=1;n<=m;n<<=1); for(int i=0;i<n;i++){ pos[i]=(pos[i>>1]>>1)|((i&1)?(n>>1):0); } invn=pw(n,mod-2),invG=pw(G,mod-2); ntt(f,1),ntt(g,1); for(int i=0;i<n;i++){ f[i]=f[i]*g[i]%mod; } ntt(f,0); for(int i=0;i<=m;i++){ printf("%lld ",f[i]*invn%mod); } return 0; } //数组需要开到2的幂次以上,不是两倍 //第一个单位根从1开始 //mod的大小开在最大系数以上防止模掉
#include<iostream> #include<cstdio> #define ll long long using namespace std; const int N=3e5+10,mod=998244353,G=3; int n,pos[N]; ll a[N],b[N],c[N],invG; ll pw(ll x,ll k){ ll num=1; while(k){ if(k&1)num=num*x%mod; x=x*x%mod; k>>=1; } return num; } void ntt(ll *f,int n,bool flag){ for(int i=0;i<n;i++)if(i<pos[i])swap(f[i],f[pos[i]]); for(int p=2;p<=n;p<<=1){ int len=p>>1; ll fir=pw((flag?G:invG),(mod-1)/p); for(int k=0;k<n;k+=p){ ll bur=1; for(int l=k;l<len+k;l++){ ll tt=f[l+len]*bur%mod; f[len+l]=((f[l]-tt)%mod+mod)%mod; f[l]=(f[l]+tt)%mod; bur=bur*fir%mod; } } } } void getinv(int now,ll *a,ll*b){ if(now==1){b[0]=pw(a[0],mod-2);return;} getinv((now+1)>>1,a,b); int goal=1; while(goal<(now<<1))goal<<=1; ll invn=pw(goal,mod-2); for(int i=0;i<goal;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(goal>>1):0); for(int i=0;i<now;i++)c[i]=a[i]; for(int i=now;i<goal;i++)c[i]=0; ntt(c,goal,1),ntt(b,goal,1); for(int i=0;i<goal;i++)b[i]=((2ll-c[i]*b[i]%mod)%mod+mod)%mod*b[i]%mod; ntt(b,goal,0); for(int i=0;i<now;i++)b[i]=b[i]*invn%mod; for(int i=now;i<goal;i++)b[i]=0; } int main(){ scanf("%d",&n); for(int i=0;i<n;i++)scanf("%lld",&a[i]); invG=pw(G,mod-2); getinv(n,a,b); for(int i=0;i<n;i++)printf("%lld ",b[i]); return 0; } //递归每一层本质对x^now取模,逆元数组b每次处理完要把被模掉的多余部分清空。
一些题目:
快速傅里叶之二:
#include<iostream> #include<cstdio> using namespace std; const int N=3e5+10,G=3; const long long mod=2281701377; int n,m,pos[N]; long long a[N],b[N],invG,invn; long long pw(long long x,long long k){ long long num=1; while(k){ if(k&1)num=num*x%mod; x=x*x%mod; k>>=1; } return num; } void ntt(long long *f,bool flag){ for(int i=0;i<n;i++){ if(i<pos[i])swap(f[i],f[pos[i]]); } for(int p=2;p<=n;p<<=1){ int len=p>>1; long long fir=pw((flag?G:invG),(mod-1)/p); for(int k=0;k<n;k+=p){ long long bur=1; for(int l=k;l<k+len;l++){ long long tt=bur*f[l+len]%mod; f[l+len]=((f[l]-tt)%mod+mod)%mod; f[l]=(f[l]+tt)%mod; bur=bur*fir%mod; } } } } int main() { scanf("%d",&n); for(int i=0;i<n;i++)scanf("%lld%lld",&a[i],&b[i]); for(int i=1;i<=n/2;i++)swap(b[i-1],b[n-i]); for(m=n+n,n=1;n<=m-2;n<<=1); invn=pw(n,mod-2),invG=pw(G,mod-2); for(int i=1;i<n;i++){ pos[i]=(pos[i>>1]>>1)|((i&1)?(n>>1):0); } ntt(a,1),ntt(b,1); for(int i=0;i<n;i++){ a[i]=a[i]*b[i]%mod; } ntt(a,0); for(int i=m/2-1;i<=m-2;i++){ printf("%lld\n",a[i]*invn%mod); } return 0; }
#include<iostream> #include<cstdio> #include<cmath> using namespace std; const int N=3e5+10; double pi=acos(-1); int n,m,pos[N]; struct node{ double x,y; node(double xx=0,double yy=0){ x=xx,y=yy; } node operator + (node const &u) const{ return node(x+u.x,y+u.y); } node operator - (node const &u) const{ return node(x-u.x,y-u.y); } node operator * (node const &u) const{ return node(x*u.x-y*u.y,y*u.x+x*u.y); } }f[N],g[N],f0[N]; void fft(node *f,bool flag){ for(int i=0;i<n;i++){ if(i<pos[i])swap(f[i],f[pos[i]]); } for(int p=2;p<=n;p<<=1){ int len=p>>1; node fir=node(cos(2*pi/p),sin(2*pi/p)); if(!flag)fir.y=-fir.y; for(int k=0;k<n;k+=p){ node bur=node(1,0); for(int l=k;l<len+k;l++){ node tt=bur*f[len+l]; f[len+l]=f[l]-tt; f[l]=f[l]+tt; bur=bur*fir; } } } } int main() { scanf("%d",&n); for(int i=1;i<=n;i++){ scanf("%lf",&f[i].x); f0[n-i+1].x=f[i].x; g[i].x=1.0/i/i;//(i*i)炸int } for(m=n+n,n=1;n<=m;n<<=1); for(int i=0;i<n;i++){ pos[i]=(pos[i>>1]>>1)|((i&1)?n>>1:0); } fft(f,1),fft(g,1),fft(f0,1); for(int i=0;i<n;i++){ f[i]=f[i]*g[i]; f0[i]=f0[i]*g[i]; } fft(f,0),fft(f0,0); for(int i=1;i<=m/2;i++){ printf("%.3lf\n",(f[i].x-f0[m/2-i+1].x)/n); } return 0; } //反转是基操
Tyvj1953 Normal:
#include<iostream> #include<cstdio> #define ll long long using namespace std; const int N=200010,mod=998244353,G=3; int nn,n,root,siz[N],vis[N],f[N],sum,dis[N],pos[N]; int ver[2*N],head[N],Next[2*N],tot,maxx; ll A[N],invG,invn,ans[N]; long double ans1; void add(int x,int y){ ver[++tot]=y; Next[tot]=head[x]; head[x]=tot; } ll pw(ll x,ll k){ ll num=1; while(k){ if(k&1)num=num*x%mod; x=x*x%mod; k>>=1; } return num; } void findroot(int x,int fa){ siz[x]=1,f[x]=0; for(int i=head[x];i;i=Next[i]){ int y=ver[i]; if(vis[y]||y==fa)continue; findroot(y,x); siz[x]+=siz[y]; f[x]=max(f[x],siz[y]); } f[x]=max(f[x],sum-siz[x]); if(f[x]<f[root])root=x; } void ntt(ll *f,bool flag){ for(int i=0;i<nn;i++)if(i<pos[i])swap(f[i],f[pos[i]]); for(int p=2;p<=nn;p<<=1){ int len=p>>1; ll fir=pw((flag?G:invG),(mod-1)/p); for(int k=0;k<nn;k+=p){ ll bur=1; for(int l=k;l<k+len;l++){ ll tt=bur*f[l+len]%mod; f[l+len]=((f[l]-tt)%mod+mod)%mod; f[l]=(f[l]+tt)%mod; bur=bur*fir%mod; } } } } void getdis(int x,int fa){ A[dis[x]]++; maxx=max(maxx,dis[x]); for(int i=head[x];i;i=Next[i]){ int y=ver[i]; if(vis[y]||y==fa)continue; dis[y]=dis[x]+1; getdis(y,x); } } void cal(int x,int lon,int val){ dis[x]=lon; for(int i=0;i<nn;i++)A[i]=0,pos[i]=0; maxx=0; getdis(x,0); for(nn=1;nn<=(maxx*2);nn<<=1); for(int i=0;i<nn;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(nn>>1):0); invn=pw(nn,mod-2); ntt(A,1); for(int i=0;i<nn;i++)A[i]=A[i]*A[i]%mod; ntt(A,0); for(int i=0;i<nn;i++)ans[i]=(ans[i]+A[i]*invn*val%mod+mod)%mod; } void solve(int x){ vis[x]=1; cal(x,0,1); for(int i=head[x];i;i=Next[i]){ int y=ver[i]; if(vis[y])continue; cal(y,1,-1); sum=siz[y]; findroot(y,root=0); solve(root); } } int main(){ scanf("%d",&n); f[0]=mod; for(int i=1,x,y;i<n;i++){ scanf("%d%d",&x,&y); x++,y++; add(x,y),add(y,x); } invG=pw(G,mod-2); sum=n; findroot(1,0); solve(root); for(int i=0;i<n;i++){ ans1+=(1.0/(i+1))*ans[i]; } printf("%.4Lf",ans1); return 0; } //转化成每个点的贡献:每个点的贡献即为它在分治树上的深度 //考虑一个点对的贡献:点x在点y计数时产生1的贡献,说明点y是x到y这条路径上被选出来的第一个点。 //如果选了路径以外的点,对x和y的关系没有影响,它们在同一棵子树中。 //如果选了路径上其它点,则x和y会被分到两个不同子树中,且都对选出来的点产生1的贡献。 //x-y这条路径上每个点被选中的概率是相同的,所以(x,y)产生贡献的期望为1/(len(x,y)+1)(len+1即为路径上的点数) //统计所有长度的路径的数量即可,可以点分治+FFT在O(nlog^2n)的复杂度内求出 //注意确保每次NTT之前都把边界卡在当前子树深度范围处,保证总复杂度正确
Triple:
#include<iostream> #include<cstdio> using namespace std; const int N=150010,g=3; const long long mod=2281701377; int n,m,pos[N],cnt[N],maxx; long long F[N],G[N],H[N],invg,invn,ans[N],inv2,inv3; long long pw(long long x,long long k){ long long num=1; while(k){ if(k&1)num=num*x%mod; x=x*x%mod; k>>=1; } return num; } void ntt(long long *f,bool flag){ for(int i=0;i<n;i++){ if(i<pos[i])swap(f[i],f[pos[i]]); } for(int p=2;p<=n;p<<=1){ int len=p>>1; long long fir=pw((flag?g:invg),(mod-1)/p); for(int k=0;k<n;k+=p){ long long bur=1; for(int l=k;l<k+len;l++){ long long tt=bur*f[l+len]%mod; f[l+len]=((f[l]-tt)%mod+mod)%mod; f[l]=(f[l]+tt)%mod; bur=bur*fir%mod; } } } } int main() { scanf("%d",&n); for(int i=1,x;i<=n;i++){ scanf("%d",&x); F[x]++; ans[x]++; cnt[x]++; maxx=max(maxx,x); } for(m=maxx*3,n=1;n<=m;n<<=1); for(int i=0;i<n;i++){ pos[i]=(pos[i>>1]>>1)|((i&1)?(n>>1):0); } invg=pw(g,mod-2),invn=pw(n,mod-2),inv2=pw(2,mod-2); ntt(F,1); for(int i=0;i<n;i++){ H[i]=F[i]*F[i]%mod; G[i]=H[i]; } ntt(G,0); for(int i=0;i<=n;i++){ G[i]=G[i]*invn%mod; if(i%2==0&&cnt[i/2])G[i]=((G[i]-cnt[i/2])%mod+mod)%mod; G[i]=G[i]*inv2%mod; ans[i]+=G[i]; } ntt(G,1); for(int i=0;i<n;i++){ G[i]=G[i]*F[i]%mod; } ntt(G,0); for(int i=0;i<n;i++){ H[i]=H[i]*F[i]%mod; } ntt(H,0); for(int i=0;i<=n;i++){ H[i]=H[i]*invn%mod,G[i]=G[i]*invn%mod; if(i%3==0&&cnt[i/3])H[i]=((H[i]-cnt[i/3])%mod+mod)%mod; H[i]=H[i]*invg%mod; ans[i]+=G[i]-H[i]; } for(int i=0;i<n;i++){ if(ans[i]){ printf("%d %lld\n",i,ans[i]); } } return 0; } //F+ //(F*F(=H)-F(每个原数的平方项的系数减原数的数量))/2(=G)+ //G*F-((H*F-F(每个原数的立方项的系数减原数的数量))/3) //注意FFT和IDFT在函数中的差别 //注意除法用逆元 //注意计算使质数mod的范围大于结果的值 //正确式子: //先构造Ai为x指数的生成函数A(x) //再构造2Ai为指数的生成函数B(x) //再构造3Ai为指数的生成函数C(x) //A(x)+(A^2(x)-B(x))/2+(A^3(x)-3*A(x)*B(x)+2*C(x))/6
//两个多项式,A的项系数为同一位置是否为a,B则为是否为b(字符串预先拓展,中间加特殊字符# //A*A,B*B,每个位置的sum为两个多项式中i*2的项的系数相加 //每一项拓展成2^a[i]-1,各项相加(这就是全部的情况了) //减去不合法的数目——回文子串的数目 //manacher! #include<iostream> #include<cstdio> #include<cstring> #define ll long long using namespace std; const int N=6e5+10,G=3; const long long mod=2281701377,mod0=1e9+7; char s[N],s0[N]; int lens,m,n,pos[N]; ll A[N],B[N],invG,invn,ans[N],sum,inv2,inv; ll pw(ll x,ll k){ ll num=1; while(k){ if(k&1)num=num*x%mod; x=x*x%mod; k>>=1; } return num; } ll pw0(ll x,ll k){ ll num=1; while(k){ if(k&1)num=num*x%mod0; x=x*x%mod0; k>>=1; } return num; } void ntt(ll *f,bool flag){ for(int i=0;i<n;i++){ if(i<pos[i])swap(f[i],f[pos[i]]); } for(int p=2;p<=n;p<<=1){ int len=p>>1; ll fir=pw((flag?G:invG),(mod-1)/p); for(int k=0;k<n;k+=p){ ll bur=1; for(int l=k;l<k+len;l++){ ll tt=f[len+l]*bur%mod; f[len+l]=((f[l]-tt)%mod+mod)%mod; f[l]=(f[l]+tt)%mod; bur=bur*fir%mod; } } } } ll lon[N]; void manacher(){ int right=0,pos=0; lon[0]=1; s[0]='#'; s[m+1]='#'; for(int i=1;i<=m;i++){ if(right<i){ lon[i]=1; while(s[i-lon[i]]==s[i+lon[i]]&&i-lon[i]>=0&&i+lon[i]<=m+1)lon[i]++; right=i+lon[i]-1; pos=i; } else{ int j=2*pos-i; if(i+lon[j]-1>right)lon[i]=right-i+1; else if(i+lon[j]-1<right)lon[i]=lon[j]; else{ lon[i]=lon[j]; while(s[i-lon[i]]==s[i+lon[i]]&&i-lon[i]>=0&&i+lon[i]<=m+1)lon[i]++; right=i+lon[i]-1; pos=i; } } } } int main() { scanf("%s",s0+1); lens=strlen(s0+1); for(int i=lens;i>=1;i--){ s[i*2-1]=s0[i]; s[i*2-2]='#'; if(s0[i]=='a')A[i*2-1]=1; else B[i*2-1]=1; } for(m=lens*2-1,n=1;n<=2*m;n<<=1); for(int i=0;i<n;i++){ pos[i]=(pos[i>>1]>>1)|((i&1)?(n>>1):0); } invn=pw(n,mod-2),invG=pw(G,mod-2); ntt(A,1),ntt(B,1); for(int i=0;i<n;i++){ A[i]=A[i]*A[i]%mod; B[i]=B[i]*B[i]%mod; } ntt(A,0),ntt(B,0); for(int i=0;i<n;i++){ A[i]=A[i]*invn%mod; B[i]=B[i]*invn%mod; } for(int i=1;i<=m;i++){ if(s[i]=='a')A[i*2]=(A[i*2]+1)%mod; else if(s[i]=='b')B[i*2]=(B[i*2]+1)%mod; } inv=pw(2,mod-2); for(int i=0;i<n;i++){ A[i]=A[i]*inv%mod; B[i]=B[i]*inv%mod; } inv2=pw0(2,mod0-2); for(int i=0;i<=m;i++){ ans[i]=(A[i*2]+B[i*2])%mod0; ans[i]=((pw0(2,ans[i])-1)%mod0+mod0)%mod0; sum=(sum+ans[i])%mod0; } manacher(); for(int i=1;i<=m;i++){ if(lon[i]%2==1)lon[i]--; lon[i]=lon[i]*inv2%mod0; sum=((sum-lon[i])%mod0+mod0)%mod0; } printf("%lld\n",sum); return 0; }
#include<iostream> #include<cstdio> #include<cmath> #include<cstring> #define ll long long using namespace std; const int N=1e5+10,G=3; const long long mod=1004535809; ll n,m,x,s,g,rec[N]; ll pri[N],cnt,mm,nn,pos[N]; ll A[N],invn,invG,ans; ll pw(ll x,ll k){ ll num=1; while(k){ if(k&1)num=num*x%mod; x=x*x%mod; k>>=1; } return num; } ll pw0(ll x,ll k){//!! ll num=1; while(k){ if(k&1)num=num*x%m; x=x*x%m; k>>=1; } return num; } int check(ll x){ for(int i=1;i<=cnt;i++){ if(pw0(x,pri[i])==1)return 0; } return 1; } void ntt(ll *f,bool flag){ for(int i=0;i<nn;i++){ if(i<pos[i])swap(f[i],f[pos[i]]); } for(int p=2;p<=nn;p<<=1){ int len=p>>1; ll fir=pw((flag?G:invG),(mod-1)/p); for(int k=0;k<nn;k+=p){ ll bur=1; for(int l=k;l<len+k;l++){ ll tt=bur*f[len+l]%mod; f[len+l]=((f[l]-tt)%mod+mod)%mod; f[l]=(f[l]+tt)%mod; bur=bur*fir%mod; } } } } void ks(ll *A,ll k){ ll B[N]; memset(B,0,sizeof(B)); for(int i=0;i<nn;i++)B[i]=A[i],A[i]=0; A[0]=1; while(k){ if(k&1){ ntt(A,1),ntt(B,1); for(int i=0;i<nn;i++){ A[i]=A[i]*B[i]%mod; } ntt(A,0),ntt(B,0); for(int i=0;i<nn;i++){ A[i]=A[i]*invn%mod; B[i]=B[i]*invn%mod; if(i>m-1){ A[i-(m-1)]=(A[i-(m-1)]+A[i])%mod; A[i]=0; B[i-(m-1)]=(B[i-(m-1)]+B[i])%mod; B[i]=0; } } } ntt(B,1); for(int i=0;i<nn;i++){ B[i]=B[i]*B[i]%mod; } ntt(B,0); for(int i=0;i<nn;i++){ B[i]=B[i]*invn%mod; if(i>m-1){ B[i-(m-1)]=(B[i-(m-1)]+B[i])%mod; B[i]=0; } } k>>=1; } } int main() { scanf("%lld%lld%lld%lld",&n,&m,&x,&s); for(int i=2;i*i<=m-1;i++){ if((m-1)%i==0){ pri[++cnt]=i; if((i*i)!=(m-1))pri[++cnt]=(m-1)/i; } } // pri[++cnt]=m-1; for(int i=2;i<=100;i++){ if(check(i)){ g=i; ll num=1; for(int j=1;j<=m-1;j++){ num=num*g%m; rec[num]=j; } break; } } for(int i=1;i<=s;i++){ ll xx; scanf("%lld",&xx); if(xx==0)continue; A[rec[xx%m]]++; } for(mm=m*2,nn=1;nn<=mm;nn<<=1); for(int i=0;i<nn;i++){ pos[i]=(pos[i>>1]>>1)|((i&1)?(nn>>1):0); } invn=pw(nn,mod-2),invG=pw(G,mod-2); ks(A,n); printf("%lld\n",A[rec[x]]); return 0; } //注意求原根的时候,快速幂取模不要和全局取模弄混 //这里多项式相乘10^9之后数组是没法直接存下那么多项的 //但是由于这题的特殊性质,第i项和第i-(m-1)项等价,于是可以把后面的加去前面
#include<iostream> #include<cstdio> #include<cstring> #define ll long long using namespace std; const int N=300010,G=3; const long long mod=998244353; int m,n,pos[N],mm; ll invn,invG,A[N],B[N],g[N],rec[N],inv[N],ans,sum; ll pw(ll x,ll k){ ll num=1; while(k){ if(k&1)num=num*x%mod; x=x*x%mod; k>>=1; } return num; } void work(){ rec[0]=inv[0]=rec[1]=1; for(int i=2;i<=m;i++)rec[i]=rec[i-1]*i%mod; inv[m]=pw(rec[m],mod-2); for(int i=m-1;i>=1;i--)inv[i]=inv[i+1]*(i+1)%mod; } void ntt(ll *f,bool flag){ for(int i=0;i<n;i++)if(i<pos[i])swap(f[i],f[pos[i]]); for(int p=2;p<=n;p<<=1){ int len=p>>1; ll fir=pw((flag?G:invG),(mod-1)/p); for(int k=0;k<n;k+=p){ ll bur=1; for(int l=k;l<len+k;l++){ ll tt=bur*f[l+len]%mod; f[len+l]=((f[l]-tt)%mod+mod)%mod; f[l]=(f[l]+tt)%mod; bur=bur*fir%mod; } } } } int main(){ scanf("%d",&m); for(mm=2*m,n=1;n<=mm;n<<=1); invn=pw(n,mod-2),invG=pw(G,mod-2); work(); for(int i=0;i<=m;i++){ A[i]=(((i&1)?-1:1)+mod)%mod*inv[i]%mod; B[i]=(pw(i,m+1)-1)*pw(i-1,mod-2)%mod*inv[i]%mod; } B[0]=1,B[1]=m+1; for(int i=1;i<n;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(n>>1):0); ntt(A,1),ntt(B,1); for(int i=0;i<n;i++)A[i]=A[i]*B[i]%mod; ntt(A,0); for(int i=0;i<=m;i++)g[i]=rec[i]*A[i]%mod*invn%mod; sum=1; for(int i=0;i<=m;i++)ans=(ans+sum*g[i]%mod)%mod,sum=sum*2%mod; printf("%lld\n",ans); return 0; } //注意区分变量名 //推公式懒得打了,复习的时候不会的话(退役吧)看题解吧 //第二类斯特林数递推公式:S(i,j)=S(i-1,j-1)+j*S(i-1,j) //含义:表示i个不同球放在j个相同盒子里,盒子不允许为空的方案数。 //如果前面的球放在j-1个盒子里就新拿一个盒子 ,如果前面的球已经放了j个盒子就随便选一个放进去 //(相关:排列组合问题——8种情况的球和盒子) //第二类斯特林数容斥原理公式:S(i,j)=1/j!*Σ(-1)^k*C(j,k)*(j-k)^i,(0<=k<=j) //含义:先考虑盒子不同的问题,枚举至少k个盒子为空,有C(j,k)种选盒子的方式 //球随便放在剩下的盒子里的方案是(j-k)^i(这里仍然可能出现其它盒子为空,因为球随便放了) //用容斥来计算恰好0个盒子为空的方案数,容斥系数是(-1)^k //因为实际上盒子是相同的,所以除去盒子全排列的方案数 //第二类斯特林数的性质:j^i=ΣS(i,k)*C(j,k)*k!,(0<=k<=j) //含义:j^i即为i个不同球放在j个不同盒子里,盒子可以为空的方案数 //枚举有多少盒子不为空,此时有C(j,k)种选盒子的方式,存在S(i,k)代表球放在这么多相同盒子里的方案数 //由于盒子其实是不同的,所以乘上盒子全排列的方案数 //附加:i个不同球放在j个不同盒子里且盒子不允许为空的方案数即为上式去掉枚举k个盒子不为空
#include<iostream> #include<cstdio> #define ll long long using namespace std; const int N=650010,G=3; const long long mod=1004535809; int n,m,s,limit,maxn,mm,nn,pos[N]; ll w[N],rec[10000010],inv[10000010],f[N],g[N],h[N],invn,invG,ans; ll pw(ll x,ll k){ ll num=1; while(k){ if(k&1)num=num*x%mod; x=x*x%mod; k>>=1; } return num; } void work(){ maxn=max(m,n); rec[0]=inv[0]=rec[1]=1; for(int i=2;i<=maxn;i++)rec[i]=rec[i-1]*i%mod; inv[maxn]=pw(rec[maxn],mod-2); for(int i=maxn-1;i>=1;i--)inv[i]=inv[i+1]*(i+1)%mod; } void ntt(ll *f,bool flag){ for(int i=0;i<nn;i++)if(i<pos[i])swap(f[i],f[pos[i]]); for(int p=2;p<=nn;p<<=1){ int len=p>>1; ll fir=pw((flag?G:invG),(mod-1)/p); for(int k=0;k<nn;k+=p){ ll bur=1; for(int l=k;l<k+len;l++){ ll tt=bur*f[l+len]%mod; f[l+len]=((f[l]-tt)%mod+mod)%mod; f[l]=(f[l]+tt)%mod; bur=bur*fir%mod; } } } } int main(){ scanf("%d%d%d",&n,&m,&s); for(int i=0;i<=m;i++){ scanf("%lld",&w[i]); } work(); limit=min(n/s,m); for(int i=0;i<=limit;i++){ f[i]=rec[m]*inv[m-i]%mod*rec[n]%mod*inv[n-i*s]%mod; f[i]=f[i]*pw(inv[s],i)%mod*pw(m-i,n-i*s)%mod; } for(mm=2*limit,nn=1;nn<=mm;nn<<=1); for(int i=0;i<=limit;i++)g[i]=(((i&1)?-1:1)*inv[i]%mod+mod)%mod; for(int i=0;i<=limit;i++)h[limit-i]=g[i]; for(int i=1;i<nn;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(nn>>1):0); invn=pw(nn,mod-2),invG=pw(G,mod-2); ntt(f,1),ntt(h,1); for(int i=0;i<nn;i++)f[i]=f[i]*h[i]%mod; ntt(f,0); for(int i=0;i<=limit;i++)ans=(ans+f[i+limit]*invn%mod*inv[i]%mod*w[i]%mod)%mod; printf("%lld",ans); return 0; } //先求至少i种染了s次的方案数f[i] //f[i]=C(m,i)*C(n,i*s)*(i*s)!/(s!)^i*(m-i)^(n-i*s) //含义是,在m种颜色里选择i种,在n个位置里占了哪s*i个,多重集的排列数,剩下的n-i*s个随便填剩下的颜色 //然后求恰好i种染了s次的方案数g[i],用容斥处理 //g[i]=Σ(-1)^(j-i)*C(j,i)*f[j] (j>=i) //含义是,容斥系数,加上至少i个,减去至少i+1个…每个f[j]会被多算C(j,i)次 //把组合数拆开,化简得 //g[i]=(Σ(-1)^(j-i)/(j-i)!*f[j]*j!)/i! //常见套路,设A[i]=(-1)^i/i!,B[i]=f[i]*i!,反转A数组,对A*B进行NTT,在i+n(设反转总长为n)处寻找i的答案
#include<iostream> #include<cstdio> #define ll long long using namespace std; const int N=500010,G=3; const long long mod=1004535809; int n,pos[N],nn; ll inv[N],rec[N],invG,B[N],C[N],D[N],c[N],invnn; ll pw(ll x,ll k){ ll num=1; while(k){ if(k&1)num=num*x%mod; x=x*x%mod; k>>=1; } return num; } void work(){ inv[0]=rec[0]=rec[1]=1; for(int i=2;i<=n;i++)rec[i]=rec[i-1]*i%mod; inv[n]=pw(rec[n],mod-2); for(int i=n-1;i>=1;i--)inv[i]=inv[i+1]*(i+1)%mod; } void ntt(ll *f,int n,bool flag){ for(int i=0;i<n;i++)if(i<pos[i])swap(f[i],f[pos[i]]); for(int p=2;p<=n;p<<=1){ int len=p>>1; ll fir=pw((flag?G:invG),(mod-1)/p); for(int k=0;k<n;k+=p){ ll bur=1; for(int l=k;l<k+len;l++){ ll tt=f[l+len]*bur%mod; f[l+len]=(f[l]-tt+mod)%mod; f[l]=(f[l]+tt)%mod; bur=bur*fir%mod; } } } } void getinv(int now,ll *a,ll *b){ if(now==1){b[0]=pw(a[0],mod-2);return;} getinv((now+1)>>1,a,b); int goal=1; while(goal<(now<<1))goal<<=1; ll invn=pw(goal,mod-2); for(int i=0;i<goal;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(goal>>1):0); for(int i=0;i<now;i++)c[i]=a[i]; for(int i=now;i<goal;i++)c[i]=0; ntt(b,goal,1),ntt(c,goal,1); for(int i=0;i<goal;i++)b[i]=((2ll-b[i]*c[i]%mod)+mod)%mod*b[i]%mod; ntt(b,goal,0); for(int i=0;i<now;i++)b[i]=b[i]*invn%mod; for(int i=now;i<goal;i++)b[i]=0; } int main(){ scanf("%d",&n); invG=pw(G,mod-2); work(); for(int i=0;i<=n;i++)B[i]=pw(2,1ll*i*(i-1)/2)*inv[i]%mod; for(int i=0;i<=n;i++)C[i]=pw(2,1ll*i*(i-1)/2)*inv[i-1]%mod; getinv(n,B,D); for(nn=1;nn<=(n*2);nn<<=1); for(int i=0;i<nn;i++)pos[i]=(pos[i>>1]>>1)|((i&1)?(nn>>1):0); ntt(C,nn,1),ntt(D,nn,1); for(int i=0;i<nn;i++)C[i]=C[i]*D[i]%mod; ntt(C,nn,0); invnn=pw(nn,mod-2); printf("%lld\n",C[n]*rec[n-1]%mod*invnn%mod); return 0; } //考虑用总图数减去不连通的图的数量 //总图数即2^C(n,2),任选两个点即为一条边,每条边有选或不选两种选择 //求不连通的图的数量:枚举现在已经可以确定的联通的图的大小,其它点随意安排。 //钦定1号点一直在联通的图中,因为包含1号点的连通块的大小一直在变化,所以枚举出来的所有子情况不重复 //于是设f[i]为i个点满足题目条件的方案数: //f[i]=2^C(i,2)-ΣC(n-1,j-1)* f[j]*2^C(i-j,2),(1<=j<=i-1) //转化式子,拆开不是指数的组合数,先两边同除以(i-1)!,移项,设0!=1来把枚举变成从1到i //然后把单项式分类成未知数形式相似的几部分,即卷积经典形式 //这时可以看出,设A[i]=f[i]/(i-1)!,B[i]=2^(i,2)/i!,C[i]= 2^(i,2)/(i-1)! //A*B=C,题目要求A //A=C*B^-1,对B多项式求逆元即可
持续补完。
对自己的记性没有太大指望。