洛谷 3784(bzoj 4913) [SDOI2017]遗忘的集合——多项式求ln+MTT
题目:https://www.luogu.org/problemnew/show/P3784
https://www.lydsy.com/JudgeOnline/problem.php?id=4913
和洛谷3489“付公主的背包”一样的套路。
要设 a[ i ] 表示第 i 个值有没有出现。
然后就有 \( \prod\limits_i(\frac{1}{1-x^i})^{a_i} = f(x) \)
因为有 \( \prod \) ,所以两边取 ln 。
\( \sum\limits_{i}a_{i}ln(\frac{1}{1-x^i}) = ln(f(x)) \)
现在想求一个 \( ln(\frac{1}{1-x^i}) \) 的更优美的形式(一般是形如 \( \sum \) 的),来更简单地刻画 a[ i ] 和 f[ i ] 的关系。(f[ i ] 是 ln( f(x) ) 的第 i 项系数)
因为有 \( ln \) ,所以先求导再积分来化式子。
并且 \( \frac{f'(x)}{f(x)} \) 了之后,把 \( f'(x) \) 写成 \( \sum \) 的形式,用 \( f(x) \) 和 \( \int \) 化出一个更好看的 \( \sum \) 的式子。
\( \int (1-x^i)\sum\limits_{j=1}i*j*x^{i*j-1} \) // j 从 1 开始
\( = \int \sum\limits_{j=1}i*j*x^{i*j-1} - \sum\limits_{j=1}i*j*x^{i*(j+1)-1} \)
\( = \int \sum\limits_{j=1}i*x^{i*j-1} \)
\( = \sum\limits_{j=0}\frac{1}{j}*x^{i*j} \)
所以 \( \sum\limits_{i=1}a_i\sum\limits_{j=0}\frac{1}{j}x^{i*j} = ln(f(x)) \)
\( \sum\limits_{i=1}\sum\limits_{j=0}a_i*\frac{1}{j} = f[i*j] \)
\( f[i]=\sum\limits_{j\|i}a_j*\frac{j}{i} \)
把分母的 i 乘到左边,然后莫比乌斯反演一下就知道 \( a_i *i= \sum\limits_{j\|i}f[j]*j*u(i/j) \)
实现的时候要写 MTT 。写拆系数 FFT 的话需要 long double 。自己写的三模数 NTT 还没调出来,不知是哪里出错。
有许许多多的细节需要注意。
#include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define db long double #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } const int N=(1<<19)+5; int n,p,f[N],g[N],u[N],pri[N]; bool vis[N]; int upt(int x){if(x>=p)x-=p;if(x<0)x+=p;return x;} int pw(int x,int k) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%p;x=(ll)x*x%p;k>>=1;}return ret;} namespace poly{ const db pi=acos(-1); struct cpl{ db x,y; cpl(db x=0,db y=0):x(x),y(y) {} cpl operator+ (const cpl &b)const {return cpl(x+b.x,y+b.y);} cpl operator- (const cpl &b)const {return cpl(x-b.x,y-b.y);} cpl operator* (const cpl &b)const {return cpl(x*b.x-y*b.y,x*b.y+y*b.x);} cpl operator/ (const int &b)const {return cpl(x/b,y/b);} }; cpl conj(cpl a){return cpl(a.x,-a.y);} int len,r[N],inv[N]; cpl Wn[N]; int bs,pbs,bs2; cpl pa[N],pb[N],pc[N],pd[N]; int A[N],B[N],tp[N]; void init() { int tmp=sqrt(p); for(bs=0,pbs=1;pbs<=tmp;bs++,pbs<<=1); bs2=bs<<1; pbs--; } void fft_pre() { for(int i=0,j=len>>1;i<len;i++) r[i]=(r[i>>1]>>1)+((i&1)?j:0); for(int R=2,m=1;R<=len;m=R,R<<=1) Wn[R]=cpl( cos(pi/m),sin(pi/m) ); } void fft(cpl *a,bool fx) { for(int i=0;i<len;i++) if(i<r[i])swap(a[i],a[r[i]]); for(int R=2;R<=len;R<<=1) { cpl wn=fx?conj(Wn[R]):Wn[R]; for(int i=0,m=R>>1;i<len;i+=R) { cpl w=cpl(1,0); for(int j=0;j<m;j++,w=w*wn) { cpl x=a[i+j], y=w*a[i+m+j]; a[i+j]=x+y; a[i+m+j]=x-y; } } } if(!fx)return; for(int i=0;i<len;i++)a[i]=a[i]/len; } void mtt(int n1,int *a,int n2,int *b,int *c) { int n3=n1+n2-1; for(len=1;len<n3;len<<=1); fft_pre(); //for(int i=0;i<n1;i++) pa[i]=cpl(a[i]>>15,a[i]&32767); //for(int i=0;i<n2;i++) pb[i]=cpl(b[i]>>15,b[i]&32767); for(int i=0;i<n1;i++) pa[i]=cpl(a[i]>>bs,a[i]&pbs); for(int i=0;i<n2;i++) pb[i]=cpl(b[i]>>bs,b[i]&pbs); for(int i=n1;i<len;i++) pa[i]=cpl(0,0); for(int i=n2;i<len;i++) pb[i]=cpl(0,0); fft(pa,0); fft(pb,0); pa[len]=pa[0]; pb[len]=pb[0]; for(int i=0,j=len;i<len;i++,j--)//q[i]=conj(p[j]) { cpl ta=(pa[i]+conj(pa[j]))*cpl(0.5,0);//conj(*[j])!! cpl tb=(pa[i]-conj(pa[j]))*cpl(0,-0.5); cpl tc=(pb[i]+conj(pb[j]))*cpl(0.5,0); cpl td=(pb[i]-conj(pb[j]))*cpl(0,-0.5); pc[i]=ta*tc+ta*td*cpl(0,1); pd[i]=tb*tc+tb*td*cpl(0,1); } pa[0]=pb[0]=cpl(0,0); fft(pc,1); fft(pd,1); for(int i=0;i<n3;i++) { ll ta=(ll)(pc[i].x+0.5)%p; ll tb=(ll)(pc[i].y+0.5)%p; ll tc=(ll)(pd[i].x+0.5)%p; ll td=(ll)(pd[i].y+0.5)%p; c[i]=((ta<<bs2)+((tb+tc)<<bs)+td)%p; //c[i]=((ta<<30)+((tb+tc)<<15)+td)%p; } } void get_dao(int n,int *a,int *b) { for(int i=1;i<n;i++)b[i-1]=(ll)a[i]*i%p; b[n-1]=0; } void get_jf(int n,int *a,int *b) { inv[1]=1; for(int i=2;i<n;i++)inv[i]=(ll)(p-p/i)*inv[p%i]%p;//(p-..)! for(int i=n-1;i;i--)b[i]=(ll)a[i-1]*inv[i]%p;//i-- for a==b b[0]=0; } void get_inv(int n,int *a,int *b) { b[0]=pw(a[0],p-2); for(int l=2;l<=n;l<<=1) { for(int i=l>>1;i<l;i++)b[i]=0;///// mtt(l,a,l,b,tp); mtt(l,b,l,tp,tp);/////b*tp not a*tp for(int i=0;i<l;i++) b[i]=((ll)b[i]*2-tp[i]+p)%p; } } void get_ln(int n,int *a,int *b) { get_dao(n,a,A); get_inv(n,a,B); mtt(n,A,n,B,A); get_jf(n,A,b); } } void get_mu(int n) { int cnt=0; u[1]=1; for(int i=2,d;i<=n;i++) { if(!vis[i])pri[++cnt]=i,u[i]=-1; for(int j=1;j<=cnt&&(d=i*pri[j])<=n;j++) { vis[d]=1; u[d]=-u[i]; if(i%pri[j]==0){u[d]=0; break;} } } } int main() { n=rdn();p=rdn(); poly::init();// for(int i=1;i<=n;i++)f[i]=rdn(); f[0]=1;//f[0]=1 int l=1;for(;l<=n;l<<=1);//<=n poly::get_ln(l,f,f); get_mu(n); for(int i=1;i<=n;i++)f[i]=(ll)f[i]*i%p; for(int i=1;i<=n;i++) for(int j=1,k=i;k<=n;j++,k+=i) g[k]=upt(g[k]+f[i]*u[j]); int cnt=0; for(int i=1;i<=n;i++)if(g[i])cnt++; printf("%d\n",cnt); for(int i=1;i<=n;i++)if(g[i])printf("%d ",g[i]); puts(""); return 0; }
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int upt(int x,int mod) {while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int pw(int x,int k,int mod) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;} const int N=(1<<19)+5; int p; namespace poly{ const double eps=1e-6; int m[3]={998244353,1004535809,469762049}; ll M=(ll)m[0]*m[1], A[N],B[N],C[3][N]; int len,r[N],Wn[N][2],inv[N]; int tp[N],ta[N],tb[N]; ll mul(ll a,ll b,ll mod) { a=(a%mod+mod)%mod; b=(b%mod+mod)%mod;///// ll ret=(a*b- (ll)((long double)a/mod*b+eps) *mod)%mod; if(ret<0)ret+=mod; return ret; } void ntt_pre(int len,int mod) { for(int R=2;R<=len;R<<=1) Wn[R][0]=pw( 3,(mod-1)/R,mod ), Wn[R][1]=pw( 3,(mod-1)-(mod-1)/R,mod ); } void ntt(ll *a,bool fx,int mod) { for(int i=0;i<len;i++) if(i<r[i])swap(a[i],a[r[i]]); for(int R=2;R<=len;R<<=1) { int wn=Wn[R][fx]; for(int i=0,m=R>>1;i<len;i+=R) for(int j=0,w=1;j<m;j++,w=(ll)w*wn%mod) { int x=a[i+j], y=(ll)w*a[i+m+j]%mod; a[i+j]=upt(x+y,mod); a[i+m+j]=upt(x-y,mod); } } if(!fx)return; int inv=pw(len,mod-2,mod); for(int i=0;i<len;i++)a[i]=(ll)a[i]*inv%mod; } void mtt(int n,int *a,int n2,int *b,int *c)//ok if c==a||c==b { for(len=1;len<n+n2;len<<=1); int mod; for(int i=0,j=len>>1;i<len;i++) r[i]=(r[i>>1]>>1)+((i&1)?j:0); for(int i=0;i<3;i++) { mod=m[i]; for(int j=0;j<n;j++)A[j]=a[j]; for(int j=n;j<len;j++)A[j]=0; for(int j=0;j<n2;j++)B[j]=b[j]; for(int j=n2;j<len;j++)B[j]=0; ntt_pre(len,mod); ntt(A,0,mod); ntt(B,0,mod); for(int j=0;j<len;j++)C[i][j]=(ll)A[j]*B[j]%mod; ntt(C[i],1,mod); } len=n+n2-1;//n-1 + m-1 = n+m-2 mod=m[1]; int tm=m[0],inv=pw(tm,mod-2,mod); for(int i=0;i<len;i++) { int tmp=(ll)upt(C[1][i]-C[0][i],mod)*inv%mod; c[i]=((ll)tmp*tm+C[0][i])%M; } mod=p; tm=m[2]; inv=pw(M%tm,tm-2,tm); for(int i=0;i<len;i++) { int tmp=mul((C[2][i]-c[i])%tm+tm,inv,tm); c[i]=(mul(tmp,M,mod)+c[i])%mod; } } void get_dao(int n,int *a,int *b) { for(int i=1;i<n;i++)b[i-1]=(ll)a[i]*i%p; b[n-1]=0; } void get_jf(int n,int *a,int *b) { inv[1]=1; for(int i=2;i<n;i++)inv[i]=(ll)(p-p/i)*inv[p%i]%p;//p/i for(int i=n-1;i;i--)b[i]=(ll)a[i-1]*inv[i]%p;//i--:a==b b[0]=0; } void get_inv(int n,int *a,int *b)//tb[] { b[0]=pw(a[0],p-2,p); for(int l=2,tn=1;tn<n;tn=l,l<<=1) { for(int i=tn;i<l;i++)b[i]=0; mtt(l,a,l,b,tb); mtt(l,b,l,tb,tb); for(int i=0;i<l;i++) b[i]=((ll)b[i]*2-tb[i]+p)%p; } } void get_ln(int n,int *a,int *b)//ta[],tp[]//ok if b==a {//%x^n get_dao(n,a,ta); get_inv(n,a,tp); mtt(n,ta,n,tp,ta); get_jf(n,ta,b); } } int n,f[N],ans[N],mu[N],pri[N]; bool vis[N]; void get_mu(int n) { mu[1]=1; int cnt=0; for(int i=2;i<=n;i++) { if(!vis[i])pri[++cnt]=i,mu[i]=-1; for(int j=1,d;j<=cnt&&(d=i*pri[j])<=n;j++) { vis[d]=1; if(i%pri[j]==0){mu[d]=0;break;} mu[d]=-mu[i]; } } } int main() { n=rdn();p=rdn(); for(int i=1;i<=n;i++)f[i]=rdn(); f[0]=1;//f[0]=1 poly::get_ln(n+1,f,f);//n+1 for(int i=1;i<=n;i++)f[i]=(ll)f[i]*i%p; get_mu(n); for(int i=1;i<=n;i++) for(int j=1,k=i;k<=n;j++,k+=i) ans[k]=upt(ans[k]+mu[j]*f[i],p); int cnt=0; for(int i=1;i<=n;i++)if(ans[i])cnt++; printf("%d\n",cnt); for(int i=1;i<=n;i++)if(ans[i])printf("%d ",ans[i]); puts(""); return 0; }