CF632E: Thief in a Shop(快速幂+NTT)(存疑)
A thief made his way to a shop.
As usual he has his lucky knapsack with him. The knapsack can contain k objects. There are n kinds of products in the shop and an infinite number of products of each kind. The cost of one product of kind i is a_i.
The thief is greedy, so he will take exactly k products (it's possible for some kinds to take several products of that kind).
Find all the possible total costs of products the thief can nick into his knapsack.
Input
The first line contains two integers n and k (1 ≤ n, k ≤ 1000) — the number of kinds of products and the number of products the thief will take.
The second line contains n integers a_i (1 ≤ a_i ≤ 1000) — the costs of products for kinds from 1 to n.
Output
Print the only line with all the possible total costs of stolen products, separated by a space. The numbers should be printed in the ascending order.
Examples
3 2
1 2 3
2 3 4 5 6
5 5
1 1 1 1 1
5
3 3
3 5 11
9 11 13 15 17 19 21 25 27 33
题意:给定N个数a[],让你选择K个数,可以重复选,求其组合成的和有哪些。N、K、a[]<=1000;
思路:把它看成一个最高1000000次的多项式:如果存在a[i]=x,则x次项的系数为1;然后求这个多项式的K次幂,结果中系数不为0的项的次数,就是可以由K个数构成的和。可以用FFT+快速幂;为了避免精度误差,每次相乘后把非0的系数改为1,免得数值变得很大后产生误差。复杂度O(1000000*log1000000*logK),有点大,稍微优化一下常数可以卡过。
这里尝试用NTT。由于系数最大可达1000^1000,所以需要对Mod取模;但取模后系数可能恰好变为0(真实值非0却被误判为不可达),所以我们取两个模数分别计算,只要有一个结果非0就说明该和可达(真实系数为0时在两个模数下都为0,不会误报),以此避免被hack。
快速幂+NTT 4398ms:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Primitive root shared by both NTT primes used below
// (998244353 = 119*2^23+1 and 1004535809 = 479*2^21+1).
const int G = 3;
// Largest transform length powNTT reaches for K <= 1000: at most 10 iterations,
// so len <= 2048 << 9 = 2^20. This also covers indices 1..10^6 printed by main.
const int maxn = 1 << 20;

int Mod;  // current NTT prime; main sets it before each powNTT call

// v^p mod Mod by binary exponentiation.
int qpow(int v, int p)
{
    int ans = 1;
    for (; p; p >>= 1, v = (ll)v * v % Mod)
        if (p & 1) ans = (ll)ans * v % Mod;
    return ans;
}

// In-place bit-reversal permutation of y[0..len); len must be a power of two.
void rader(int y[], int len)
{
    for (int i = 1, j = len / 2; i < len - 1; i++) {
        if (i < j) swap(y[i], y[j]);
        int k = len / 2;
        while (j >= k) j -= k, k /= 2;
        if (j < k) j += k;
    }
}

// Iterative NTT of y[0..len) modulo Mod.
// opt == 1: forward transform; opt == -1: inverse transform (including the 1/len scaling).
void NTT(int y[], int len, int opt)
{
    rader(y, len);
    for (int h = 2; h <= len; h <<= 1) {
        int wn = qpow(G, (Mod - 1) / h);
        if (opt == -1) wn = qpow(wn, Mod - 2);
        for (int j = 0; j < len; j += h) {
            int w = 1;
            for (int k = j; k < j + h / 2; k++) {
                int u = y[k];
                int t = (ll)w * y[k + h / 2] % Mod;
                y[k] = (u + t) % Mod;
                y[k + h / 2] = (u - t + Mod) % Mod;
                w = (ll)w * wn % Mod;
            }
        }
    }
    if (opt == -1) {
        int inv = qpow(len, Mod - 2);
        for (int i = 0; i < len; i++) y[i] = (ll)y[i] * inv % Mod;
    }
}

// ans <- a^x as polynomials (coefficients mod Mod), by binary exponentiation.
// ans must be all zero on entry; deg(a) must be < 1024.
// a is used as scratch space and is left in an unspecified state.
void powNTT(int ans[], int a[], int x)
{
    ans[0] = 1;
    int len = 1024;
    while (x) {
        // At iteration i: deg(a) <= 1000*2^i and deg(ans) <= 1000*(2^i - 1), so both
        // ans*a and a*a have degree <= 2000*2^i < len = 2048*2^i: no cyclic wrap-around.
        len <<= 1;
        // Transform a once and reuse it for both the multiply and the squaring
        // (the original re-did a forward/inverse/forward round-trip here).
        NTT(a, len, 1);
        if (x & 1) {
            NTT(ans, len, 1);
            for (int i = 0; i < len; i++) ans[i] = (ll)ans[i] * a[i] % Mod;
            NTT(ans, len, -1);
        }
        x >>= 1;
        if (!x) break;  // a final squaring of a would never be used
        for (int i = 0; i < len; i++) a[i] = (ll)a[i] * a[i] % Mod;
        NTT(a, len, -1);
    }
}

int A[maxn], B[maxn], ans1[maxn], ans2[maxn];

int main()
{
    int N, K, x;
    scanf("%d%d", &N, &K);
    for (int i = 1; i <= N; i++) scanf("%d", &x), A[x] = 1, B[x] = 1;
    // Coefficients can reach N^K, so work modulo two primes: a true zero stays zero under both,
    // and a nonzero coefficient is only missed if it vanishes modulo both primes.
    Mod = 998244353;
    powNTT(ans1, A, K);
    Mod = 1004535809;
    powNTT(ans2, B, K);
    for (int i = 1; i <= 1000000; i++)
        if (ans1[i] || ans2[i]) printf("%d ", i);
    return 0;
}
洛谷题解给出的代码(https://www.luogu.org/problemnew/solution/CF632E)只做一次NTT:在DFT后把每个点值单独求K次幂(pow(K)),再做一次逆变换,就得到了正确答案。
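(暂时不理解其解法的正确性;如果是正确的,其NTT的写法里可能也有玄机(因为把这个NTT板子套到其他题上,样例过不了),尚待解决。)
补充说明:该解法是正确的。设 P(x)=Σx^{a_i},则 P(x)^K 的次数不超过 K·max(a_i) ≤ 10^6 < 2^{20}=1048576,所以长度为 2^{20} 的循环卷积不会发生回绕;而对任意单位根 ω 都有 P^K(ω)=P(ω)^K,因此 DFT 后把每个点值单独求 K 次幂,再做一次逆变换,恰好得到 P(x)^K 的系数(模 p 意义下),只需一次正变换和一次逆变换。至于板子套到其他题上出错,原因在 init:单位根 w=G^{(mod-1)/len} 是按传入的 len 而不是补齐后的 lim 计算的,只有 len 恰为 2 的幂时才正确;本题传入的 1048576 正好是 2^{20},所以没有出错。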
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Primitive root shared by both NTT primes used below
// (998244353 = 119*2^23+1 and 1004535809 = 479*2^21+1).
const int G = 3;
// Transform length; must exceed deg(P^k) <= k*max(a) <= 10^6 so the cyclic convolution never wraps.
const int maxn = 1 << 20;

int mod, lim, ilim;  // current prime, padded transform length, and lim^{-1} mod `mod`
int rev[maxn];       // bit-reversal permutation for length lim
int wn[maxn + 1];    // wn[i] = w^i for i in [0, lim]; the inverse transform reads wn[lim] == 1
int A[maxn], B[maxn];

// x^y mod `mod` by binary exponentiation.
// Named qpow (was `pow`) so it does not share a name with std::pow under `using namespace std`.
inline int qpow(int x, int y)
{
    int ans = 1;
    for (; y; y >>= 1, x = (ll)x * x % mod)
        if (y & 1) ans = (ll)ans * x % mod;
    return ans;
}

// In-place NTT of a[0..lim) modulo `mod`; typ != 0 is forward, typ == 0 is inverse (scaled by 1/lim).
inline void NTT(int* a, int typ)
{
    for (int i = 0; i < lim; i++)
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int half = 1; half < lim; half <<= 1) {
        const int step = lim / (half * 2);  // wn[step] is a primitive (2*half)-th root of unity
        for (int j = 0; j < lim; j += half * 2) {
            for (int t = 0; t < half; t++) {
                // Forward uses w^(step*t); inverse uses w^(-step*t) = w^(lim - step*t).
                const int w = typ ? wn[step * t] : wn[lim - step * t];
                const int x = a[j + t];
                const int y = (ll)w * a[j + t + half] % mod;
                a[j + t] = x + y >= mod ? x + y - mod : x + y;
                a[j + t + half] = x - y < 0 ? x - y + mod : x - y;
            }
        }
    }
    if (!typ)
        for (int i = 0; i < lim; i++) a[i] = (ll)ilim * a[i] % mod;
}

// Prepare tables for transforms of length >= len (rounded up to a power of two) modulo tmod.
inline void init(int len, int tmod)
{
    mod = tmod;
    lim = 1;
    while (lim < len) lim <<= 1;
    ilim = qpow(lim, mod - 2);
    // Using lim>>1 instead of (i&1)<<s avoids a shift by -1 when lim == 1.
    rev[0] = 0;
    for (int i = 1; i < lim; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? lim >> 1 : 0);
    // The root must have order lim (the padded length), not len: dividing by len gave a wrong root
    // whenever len was not already a power of two.
    const int w = qpow(G, (mod - 1) / lim);
    wn[0] = 1;
    for (int i = 1; i <= lim; i++) wn[i] = (ll)wn[i - 1] * w % mod;
}

int main()
{
    int n, k, x;
    scanf("%d%d", &n, &k);
    for (int i = 1; i <= n; i++) scanf("%d", &x), A[x] = B[x] = 1;
    // Since deg(P^k) < lim, P^k equals the length-lim cyclic convolution power, whose DFT is the
    // pointwise k-th power of P's DFT: one forward transform, pointwise pow, one inverse transform.
    init(maxn, 998244353);
    NTT(A, 1);
    for (int i = 0; i < lim; i++) A[i] = qpow(A[i], k);
    NTT(A, 0);
    init(maxn, 1004535809);
    NTT(B, 1);
    for (int i = 0; i < lim; i++) B[i] = qpow(B[i], k);
    NTT(B, 0);
    // A sum is reachable if its coefficient is nonzero modulo either prime
    // (same output as collecting both index sets, sorting and de-duplicating).
    for (int i = 1; i <= 1000000; i++)
        if (A[i] || B[i]) printf("%d ", i);
    return 0;
}