[CF286E] Ladies' shop
Description
给出 \(n\) 个 \(\leq m\) 且不同的数 \(a_1,\dots,a_n\),现在要求从这 \(n\) 个数中选出最少的数字,满足这 \(n\) 个数字都可以由选出的数字组合成(就是做一个完全背包能做出来),并且任意组合出来的数字,只要不超过 \(m\),就必须让这个数字在给出的 \(n\) 个数中。问是否可行,如果可行,请求出最少选多少数字。 \(n,m\leq 10^6\)。
Sol
先判断是否可行,再看哪些数可以省略。
求出 \(a\) 数组的生成函数,即构造多项式 \(F(x)=\sum f_i\cdot x^i\)。\(f_i\) 为 \(1\) 当且仅当 \(a\) 数组中出现 \(a_*=i\)
然后求出 \(G(x)=F^2(i)=\sum g_i\cdot x^i\)。如果 \(g_i>0\) 那就说明给出的这 \(n\) 个数可以合成 \(i\) 。
于是就得到了从原来的 \(n\) 个数中拿出 \(0\sim 2\) 个的结果。
然而最多拿出 \(m\) 个。
所以还要继续,用快速幂求得 \(f^m\)。如果多项式快速幂的话,复杂度 \(O(n\log^2n)\),用多项式ln+多项式exp求的话,复杂度 \(O(n\log n)\)。但是多项式exp常数太大了!
事实上是有只做 \(1\) 次FFT的方法的。
显然如果 \(f_i>0\) 的话,\(g_i>0\)。
那我们只要保证满足 \(f_i=0,g_i>0,i\leq m\) 的 \(i\) 不存在就好了。
如果第一轮不存在这些不合法的,那接下来肯定也不存在。感性理解一下这就相当于构成了一个封闭的集合。
所以只做 \(1\) 次FFT就行了。
然后考虑一下哪些数可以省略
如果一个数 \(i\) 可以被其他数表示出来,那 \(g_i\) 一定 \(>2\)。所以 \(g_i=2\) 的 \(i\) 就是必选的。
时间复杂度 \(O(n\log n)\)。
Sol
#pragma GCC optimize(2)
#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
const int N=4e6+5;
const int mod=998244353;
int lim,rev[N];
int n,m,a[N],b[N];
int ksm(int a,int b=mod-2,int ans=1){
while(b){
if(b&1) ans=1ll*ans*a%mod;
a=1ll*a*a%mod;b>>=1;
} return ans;
}
int getint(){
int X=0,w=0;char ch=getchar();
while(!isdigit(ch))w|=ch=='-',ch=getchar();
while( isdigit(ch))X=X*10+ch-48,ch=getchar();
if(w) return -X;return X;
}
void ntt(int *f,int g){
for(int i=1;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
for(int mid=1;mid<lim;mid<<=1){
int tmp=ksm(g,(mod-1)/(mid<<1));
for(int R=mid<<1,j=0;j<lim;j+=R){
int w=1;
for(int k=0;k<mid;k++,w=1ll*w*tmp%mod){
int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
}
}
} if(g>3)
for(int in=ksm(lim),i=0;i<lim;i++) f[i]=1ll*f[i]*in%mod;
}
signed main(){
n=getint(),m=getint();
for(int i=1;i<=n;i++){
int x=getint();
a[x]=b[x]=1;
}
lim=1;while(lim<=m+m) lim<<=1;
for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
a[0]=1; ntt(a,3);
for(int i=0;i<lim;i++) a[i]=1ll*a[i]*a[i]%mod;
ntt(a,(mod+1)/3);
for(int i=1;i<=m;i++)
if(a[i] and !b[i]) return printf("NO"),0;
puts("YES"); int tot=0;
for(int i=1;i<=m;i++)
if(a[i]==2) tot++;
printf("%d\n",tot);
for(int i=1;i<=m;i++)
if(a[i]==2) printf("%d ",i);
return 0;
}