【SDOI2017】遗忘的集合

题目描述

好神仙啊,我还真的以为这是个构造题,结果是有唯一解的。

设答案为多项式\(a,a_i\in\{0,1\}\)

则:

\[f(x)=\Pi (\frac{1}{1-x^i})^{a_i} \]

两边取对数:

\[\begin{align} ln(f(x))&=\sum a_i ln(\frac{1}{1-x^i}) \\&=-\sum a_iln(1-x^i) \end{align} \]

我们在对\(ln(x)\)\(x=1\)处进行泰勒展开。

由:

\[\begin{align} ln(x)&=\sum_{i=0}^{\infty}\frac{ln^{[i]}(x_0)}{i!}(x-x_0)^i\\ &=\sum_{i=1}^{\infty}(-1)^{i-1}\frac{1}{i}(x-1)^i \end{align} \]

得到:

\[ln(1-x^j)=\sum_{i=1}^{\infty}(-1)\frac{x^{ij}}{i} \]

所以:

\[\begin{align} ln(f(x))&=-\sum a_i \sum_{j=1}^{\infty}(-1)\frac{x^{ij}}{j} \\&=\sum_{i=1}^n x^i\sum_{j|i}a_j\frac{j}{i} \\&=\sum_{i=1}^n\frac{x^i}{i}\sum_{j|i}a_jj \end{align} \]

求出\(ln(f(x))\)后就可以用\(nlogn\)的复杂度求出\(a\)了。

因为是任意模数,所以要写\(MTT\)

推导很自然,思路很巧妙啊。关键是要想到列出关于答案数组\(a\)的等式再去将\(a\)解出来。

代码:

#include<bits/stdc++.h>


using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}

typedef long long ll;
const int N=(1<<19);
const ll sqr=1<<15;
ll mod;
ll a[N<<2];
struct Com {
	double r,v;
	Com() {r=0,v=0;}
	Com(double a,double b) {r=a,v=b;}
};
Com operator *(const Com &a,const Com &b) {return Com(a.r*b.r-a.v*b.v,a.r*b.v+a.v*b.r);}
Com operator /(const Com &a,const double &y) {return Com(a.r/y,a.v/y);}
Com operator +(const Com &a,const Com &b) {return Com(a.r+b.r,a.v+b.v);}
Com operator -(const Com &a,const Com &b) {return Com(a.r-b.r,a.v-b.v);}
Com w[N<<2];
const double pi=acos(-1);
void FFT(Com *a,int d,int flag) {
	static int rev[N<<2];
	int n=1<<d;
	for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<d-1);
	for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int i=0;i<n;i++) w[i]=Com(cos(2*flag*i*pi/n),sin(2*flag*i*pi/n));
	for(int s=1;s<=d;s++) {
		int len=1<<s,mid=len>>1;
		for(int i=0;i<n;i+=len) {
			for(int j=0;j<mid;j++) {
				Com u=a[i+j],v=a[i+j+mid]*w[n/len*j];
				a[i+j]=u+v;
				a[i+j+mid]=u-v;
			}
		}
	}
	if(flag==-1) for(int i=0;i<n;i++) a[i]=a[i]/n;
}
Com f1[N<<2],f2[N<<2],g1[N<<2],g2[N<<2];
Com A[N<<2],B[N<<2],C[N<<2],D[N<<2];
void mul(ll *f,ll *g,int d,ll *ans) {
	int n=1<<d;
	for(int i=0;i<n;i++) {
		f1[i]=Com(f[i]/sqr,0),f2[i]=Com(f[i]%sqr,0);
		g1[i]=Com(g[i]/sqr,0),g2[i]=Com(g[i]%sqr,0);
	}
	for(int i=0;i<n<<1;i++) ans[i]=0;
	for(int i=n;i<n<<1;i++) f1[i]=f2[i]=g1[i]=g2[i]=Com(0,0);
	FFT(f1,d+1,1),FFT(f2,d+1,1),FFT(g1,d+1,1),FFT(g2,d+1,1);
	for(int i=0;i<n<<1;i++) {
		A[i]=f1[i]*g1[i];
		B[i]=f1[i]*g2[i];
		C[i]=f2[i]*g1[i];
		D[i]=f2[i]*g2[i];
	}
	FFT(A,d+1,-1),FFT(B,d+1,-1),FFT(C,d+1,-1),FFT(D,d+1,-1);
	for(int i=0;i<n;i++) ans[i]=(ll(A[i].r+0.5)*sqr%mod*sqr%mod+ll(B[i].r+0.5)*sqr+ll(C[i].r+0.5)*sqr+ll(D[i].r+0.5))%mod;
	for(int i=n;i<n<<1;i++) ans[i]=0;
}
ll ksm(ll t,ll x) {
	ll ans=1;
	for(;x;x>>=1,t=t*t%mod)
		if(x&1) ans=ans*t%mod;
	return ans;
}

int n,m;
ll tem[N<<2];
void inverse(ll *a,ll *f,int len) {
	if(len==1) {return f[0]=ksm(a[0],mod-2),void();}
	int d=ceil(log2(len));
	inverse(a,f,len>>1);
	mul(a,f,d,tem);
	mul(tem,f,d,tem);
	for(int i=0;i<len;i++) f[i]=(2*f[i]-tem[i]+mod)%mod;
}

ll inv[N<<2];
void Ln(ll *a,ll *ans,int len) {
	for(int i=0;i<len-1;i++) ans[i]=a[i+1]*(i+1)%mod;
	int d=ceil(log2(len));
	inverse(a,inv,len);
	ans[len-1]=0;
	mul(ans,inv,d,ans);
	for(int i=len-1;i;i--) ans[i]=ans[i-1]*ksm(i,mod-2)%mod;
	ans[0]=0;
}

ll f[N<<2];
int main() {
	n=Get(),mod=Get();
	for(int i=1;i<=n;i++) a[i]=Get();
	
	a[0]=1;
	int d=ceil(log2(n+1));
	Ln(a,f,1<<d);
	for(int i=1;i<=n;i++) f[i]=f[i]*i%mod;
	int tot=0;
	for(int i=1;i<=n;i++) {
		for(int j=i+i;j<=n;j+=i) {
			f[j]=(f[j]-f[i]+mod)%mod;
		}
		if(f[i]) tot++;
	}
	cout<<tot<<"\n";
	for(int i=1;i<=n;i++) if(f[i]) cout<<i<<" ";
	return 0;
}

posted @ 2019-02-20 19:57  hec0411  阅读(265)  评论(0编辑  收藏  举报