【BZOJ4589】Hard Nim(FWT)
题解:
由博弈论（Nim 游戏）可知：先手必败（即后手必胜）当且仅当所有石子堆大小的异或和为 $0$。因此题目等价于：从不超过 $m$ 的质数中可重复地选出 $n$ 个组成序列，求异或和为 $0$ 的序列个数。
做法：快速幂 + FWT（快速沃尔什变换，用于计算异或卷积）。
如果在快速幂的每一次乘法里都完整地做一遍异或卷积（正变换、逐点相乘、逆变换），复杂度是 $O(m\log m\log n)$，并不能通过。
另外要注意：两个下标不超过 $m$ 的数组做异或卷积后，结果的下标可能超过 $m$（但一定小于 $2m$），所以数组和变换长度都要按 $2m$ 来开。
// BZOJ4589 Hard Nim, version 1: fast exponentiation in which every
// multiplication is a complete XOR convolution (forward FWT, pointwise
// product, inverse FWT). Correct, but each of the O(log n1) multiplications
// re-runs both transforms.
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;

#define rep(i,h,t) for (int i = (h); i <= (t); i++)
#define me(x) memset(x, 0, sizeof(x))
#define mep(x,y) memcpy(x, y, sizeof(y))

const int N = 4e5 + 10;   // capacity of every global array
const int M = 5e4 + 10;   // upper bound of the prime sieve
const int inv2 = 5e8 + 4; // inverse of 2 modulo mo
const int mo = 1e9 + 7;

int n; // transform length (a power of two) shared by all fwt_* routines

// OR transform over a[0..n-1]: o == 1 forward, otherwise inverse.
// Entries are expected in [0, mo) and are kept there.
void fwt_or(int *a, int o) {
    for (int i = 1; i < n; i *= 2)
        for (int j = 0; j < n; j += i * 2)
            for (int k = 0; k < i; k++) {
                int &hi = a[i + j + k];
                const int lo = a[j + k];
                // Adding mo before reducing keeps the inverse non-negative.
                hi = (o == 1) ? (hi + lo) % mo : (hi - lo + mo) % mo;
            }
}

// AND transform over a[0..n-1]: o == 1 forward, otherwise inverse.
// Entries are expected in [0, mo) and are kept there.
void fwt_and(int *a, int o) {
    for (int i = 1; i < n; i *= 2)
        for (int j = 0; j < n; j += i * 2)
            for (int k = 0; k < i; k++) {
                int &lo = a[j + k];
                const int hi = a[i + j + k];
                lo = (o == 1) ? (lo + hi) % mo : (lo - hi + mo) % mo;
            }
}

// XOR transform over a[0..n-1]: o == 1 forward, o == -1 inverse.
// The inverse halves at every level, which accumulates the 1/n factor.
void fwt_xor(int *a, int o) {
    for (int i = 1; i < n; i *= 2)
        for (int j = 0; j < n; j += i * 2)
            for (int k = 0; k < i; k++) {
                int x = a[j + k], y = a[i + j + k];
                a[j + k] = (x + y) % mo;
                a[i + j + k] = (x + mo - y) % mo;
                if (o == -1) {
                    a[j + k] = 1ll * a[j + k] * inv2 % mo;
                    a[i + j + k] = 1ll * a[i + j + k] * inv2 % mo;
                }
            }
}

int p[N], cnt;      // primes produced by the sieve: p[1..cnt]
bool t[N];          // t[i] == 1 marks i as not prime
int a[N];           // a[v] == 1 iff v is a prime <= the m that was read
int now[N], ans[N]; // base and accumulator of the fast power
int ta[N], tb[N];   // scratch buffers used only by fwt()
int m;              // last index fwt() copies in/out (set to 2 * input m)

// B = A xor-convolved with B, keeping indices 0..m.
// Requires m < n (guaranteed by how main picks n). Works when A == B.
// Uses its own scratch buffers, so the global a[] is left untouched.
void fwt(int *A, int *B) {
    memset(ta, 0, sizeof(int) * n);
    memset(tb, 0, sizeof(int) * n);
    rep(i, 0, m) ta[i] = A[i], tb[i] = B[i];
    fwt_xor(ta, 1);
    fwt_xor(tb, 1);
    rep(i, 0, n - 1) ta[i] = 1ll * ta[i] * tb[i] % mo;
    fwt_xor(ta, -1);
    rep(i, 0, m) B[i] = ta[i];
}

// Returns entry 0 of the x-th XOR-convolution power of a[]: the number of
// length-x sequences of values marked in a[] whose XOR is 0.
int fsp(int x) {
    mep(now, a);
    me(ans);
    ans[0] = 1; // identity element of XOR convolution
    while (x) {
        if (x & 1) fwt(now, ans);
        fwt(now, now);
        x >>= 1;
    }
    return ans[0];
}

int main() {
#ifdef LOCAL
    // Local testing only: a judge feeds stdin and reads stdout.
    freopen("1.in", "r", stdin);
    freopen("1.out", "w", stdout);
#endif
    // Linear sieve up to M.
    t[1] = 1;
    rep(i, 2, M) {
        if (!t[i]) p[++cnt] = i;
        for (int j = 1; j <= cnt && p[j] * i <= M; j++) {
            t[p[j] * i] = 1;
            if (i % p[j] == 0) break;
        }
    }
    int n1;
    while (cin >> n1 >> m) {
        me(a);
        rep(i, 1, cnt) {
            if (p[i] > m) break;
            a[p[i]] = 1;
        }
        // XOR of values <= m stays below 2m; n is the smallest power of two above 2m.
        m *= 2;
        for (n = 1; n <= m; n <<= 1);
        cout << fsp(n1) << '\n';
    }
    return 0;
}
改进做法：只对初始数组做一次 FWT，在变换域中对每个位置单独做快速幂（逐点相乘），最后再做一次 IFWT 变换回去即可。
这样复杂度为 $O(m\log m + m\log n)$。
// BZOJ4589 Hard Nim, version 2: transform once, raise every transformed
// entry to the n1-th power pointwise, then transform back once.
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;

#define rep(i,h,t) for (int i = (h); i <= (t); i++)
#define me(x) memset(x, 0, sizeof(x))
#define mep(x,y) memcpy(x, y, sizeof(y))

const int N = 4e5 + 10;   // capacity of every global array
const int M = 5e4 + 10;   // upper bound of the prime sieve
const int inv2 = 5e8 + 4; // inverse of 2 modulo mo
const int mo = 1e9 + 7;

int n; // transform length (a power of two) shared by all fwt_* routines

// OR transform over a[0..n-1]: o == 1 forward, otherwise inverse.
// Entries are expected in [0, mo) and are kept there.
void fwt_or(int *a, int o) {
    for (int i = 1; i < n; i *= 2)
        for (int j = 0; j < n; j += i * 2)
            for (int k = 0; k < i; k++) {
                int &hi = a[i + j + k];
                const int lo = a[j + k];
                // Adding mo before reducing keeps the inverse non-negative.
                hi = (o == 1) ? (hi + lo) % mo : (hi - lo + mo) % mo;
            }
}

// AND transform over a[0..n-1]: o == 1 forward, otherwise inverse.
// Entries are expected in [0, mo) and are kept there.
void fwt_and(int *a, int o) {
    for (int i = 1; i < n; i *= 2)
        for (int j = 0; j < n; j += i * 2)
            for (int k = 0; k < i; k++) {
                int &lo = a[j + k];
                const int hi = a[i + j + k];
                lo = (o == 1) ? (lo + hi) % mo : (lo - hi + mo) % mo;
            }
}

// XOR transform over a[0..n-1]: o == 1 forward, o == -1 inverse.
// The inverse halves at every level, which accumulates the 1/n factor.
void fwt_xor(int *a, int o) {
    for (int i = 1; i < n; i *= 2)
        for (int j = 0; j < n; j += i * 2)
            for (int k = 0; k < i; k++) {
                int x = a[j + k], y = a[i + j + k];
                a[j + k] = (x + y) % mo;
                a[i + j + k] = (x + mo - y) % mo;
                if (o == -1) {
                    a[j + k] = 1ll * a[j + k] * inv2 % mo;
                    a[i + j + k] = 1ll * a[i + j + k] * inv2 % mo;
                }
            }
}

int p[N], cnt;      // primes produced by the sieve: p[1..cnt]
bool t[N];          // t[i] == 1 marks i as not prime
int a[N];           // a[v] == 1 iff v is a prime <= the m that was read
int now[N], ans[N]; // transformed base and accumulator of the fast power
int m;              // value bound read from input (doubled before sizing n)

// Returns entry 0 of the x-th XOR-convolution power of a[]: the number of
// length-x sequences of values marked in a[] whose XOR is 0.
// XOR convolution becomes pointwise multiplication after fwt_xor, so the
// power is taken entry by entry in the transformed domain.
int fsp(int x) {
    mep(now, a);
    me(ans);
    ans[0] = 1; // identity element of XOR convolution
    fwt_xor(ans, 1);
    fwt_xor(now, 1);
    while (x) {
        if (x & 1) rep(i, 0, n - 1) ans[i] = 1ll * ans[i] * now[i] % mo;
        rep(i, 0, n - 1) now[i] = 1ll * now[i] * now[i] % mo;
        x >>= 1;
    }
    fwt_xor(ans, -1);
    return ans[0];
}

int main() {
#ifdef LOCAL
    // Local testing only: a judge feeds stdin and reads stdout.
    freopen("1.in", "r", stdin);
    freopen("1.out", "w", stdout);
#endif
    // Linear sieve up to M.
    t[1] = 1;
    rep(i, 2, M) {
        if (!t[i]) p[++cnt] = i;
        for (int j = 1; j <= cnt && p[j] * i <= M; j++) {
            t[p[j] * i] = 1;
            if (i % p[j] == 0) break;
        }
    }
    int n1;
    while (cin >> n1 >> m) {
        me(a);
        rep(i, 1, cnt) {
            if (p[i] > m) break;
            a[p[i]] = 1;
        }
        // XOR of values <= m stays below 2m; n is the smallest power of two above 2m.
        m *= 2;
        for (n = 1; n <= m; n <<= 1);
        cout << fsp(n1) << '\n';
    }
    return 0;
}