Loading

题解 P4463 【[集训队互测2012] calc】

题意

一个序列\(a_1,a_2,\ldots,a_n\)合法,当且仅当:

  • \(\forall i\in[1,n],a_i\in[1,k]\)
  • \(\forall i\not=j,a_i\not=a_j\)

一个序列的值为它的乘积,即\(\prod_{i=1}^na_i\)

给定\(n,k\),求所有不同的序列的值的和。

题解

不难发现答案可以写成生成函数的形式:

\[F(x)=\prod_{i=1}^k(1+ix) \]

但此时是无序的,因为\(a_1,a_2,\ldots,a_n\)互不相等,乘\(n!\)即可。最后答案为\(n![x^n]F(x)\)

两边同取\(\ln\)把积变和。

\[\ln F(x)=\sum_{i=1}^k\ln(1+ix) \]

右边求和中的每一项\(\ln(1+ix)\)都可以展开成幂级数。具体的:

\[\begin{aligned} \ln(1+ix)=&\int(\ln(1+ix))^\prime dx &\text{(求导再求原)}\\ =&\int \frac{i}{1+ix}dx &\text{(复合函数求导)}\\ =&\int i\sum_{j=0}^\infty (-ix)^jdx &\text{(还原等比数列求和)}\\ =&\int i\sum_{j=0}^\infty (-1)^ji^jx^jdx &\text{(拆括号)}\\ =&\sum_{j=0}^\infty (-1)^ji^{j+1}\int x^jdx &\text{(交换求和顺序)}\\ =&\sum_{j=0}^\infty (-1)^ji^{j+1}\frac{x^{j+1}}{j+1} &\text{(求原函数)}\\ =&\sum_{j=1}^\infty(-1)^{j-1}i^j\frac{x^j}{j} &\text{(用j-1替换j)} \end{aligned} \]

于是乎带进去。

\[\ln F(x)=\sum_{i=1}^k\sum_{j=1}^\infty(-1)^{j-1}i^j\frac{x^j}{j} \]

整理一下下

\[\sum_{j=1}^\infty\frac{(-1)^{j-1}}{j}(\sum_{i=1}^ki^j)x^j \]

\(f_y(x)=\sum_{i=1}^xi^y\)

\[\sum_{j=1}^\infty\frac{(-1)^{j-1}}{j}f_j(k)x^j \]

根据经验我们知道\(f_y(x)=\sum_{i=1}^xi^y\)是一个关于\(x\)的\(y+1\)次多项式,可以用拉格朗日插值得到。

\[f(x)=\sum_{i=0}^{n}y_i\prod_{j\not=i}\frac{x-x_j}{x_i-x_j} \]

此时求\(f_y(x)\)的复杂度为\(\mathcal{O}(y^2)\),但由于点是我们自己取的,取连续的点可以做到\(\mathcal{O}(y)\)。我们已知\(i\in[0,n]\)时的取值\(y_i\),那么魔改公式:

\[f(k)=\sum_{i=0}^ny_i\prod_{j\not=i}\frac{k-j}{i-j} \]

\[pre_i=\prod_{j=0}^{i}(k-j),\quad suf_i=\prod_{j=i}^n(k-j) \]

\[f(k)=\sum_{i=0}^ny_i\frac{pre_{i-1}\times suf_{i+1}}{i!(n-i)!(-1)^{n-i}} \]

注意符号。

算出系数后再\(\exp\)一下就好了。

然后就有\(\mathcal{O}(n^2+n\log n)\)的算法了。

只不过需要MTT

代码

#include<bits/stdc++.h>
// Buffered reader for decimal integers on stdin.
namespace in{
	// buf holds the current chunk of input; [p1,p2) is the part not yet consumed.
	char buf[1<<21],*p1=buf,*p2=buf;
	// Returns the next input byte as a value in [0,255], refilling buf from stdin
	// when the chunk is exhausted. Returns EOF once stdin yields no more bytes.
	inline int getc(){
		if(p1==p2){
			p1=buf;
			p2=buf+fread(buf,1,1<<21,stdin);
			if(p1==p2)return EOF;
		}
		// Cast through unsigned char so bytes >= 0x80 are never negative and can
		// be told apart from EOF (and are valid arguments for isdigit).
		return static_cast<unsigned char>(*p1++);
	}
	// Reads one integer into t: skips every non-digit byte (a '-' seen while
	// skipping marks the number negative), then accumulates the digit run.
	// If the input ends before any digit is found, t is left at 0 instead of
	// looping forever.
	template <typename T>inline void read(T& t){
		t=0;
		bool negative=false;
		int ch=getc();   // int, not char: EOF must stay distinguishable
		while(ch!=EOF&&!isdigit(ch)){
			if(ch=='-')negative=true;
			ch=getc();
		}
		while(ch!=EOF&&isdigit(ch)){
			t=t*10+(ch-48);
			ch=getc();
		}
		if(negative)t=-t;
	}
	// Reads several integers in order, one per argument.
	template <typename T,typename... Args> inline void read(T& t, Args&... args){read(t);read(args...);}
}
// Buffered writer for decimal integers on stdout.
namespace out{
	// buffer[0..p1] is pending output; p1==-1 means empty, p2 is the last valid index.
	char buffer[1<<21];int p1=-1;const int p2 = (1<<21)-1;
	// Writes all pending bytes to stdout and empties the buffer.
	inline void flush(){fwrite(buffer,1,p1+1,stdout),p1=-1;}
	// Appends one byte, flushing first when the buffer is full.
	inline void putc(const char &x) {if(p1==p2)flush();buffer[++p1]=x;}
	// Appends the decimal representation of x (with a leading '-' if negative).
	// The scratch array is sized for 128-bit values, so 64-bit arguments no
	// longer overrun it; digits are produced least-significant first and then
	// emitted in reverse.
	template <typename T>void write(T x) {
		char digits[40];
		int len=0;
		if(x>=0){
			do{digits[len++]=static_cast<char>(x%10+48);x/=10;}while(x);
		}else{
			putc('-');
			// x%10 is in [-9,0] here; negating per digit avoids negating x
			// itself, which would overflow for the minimum value.
			do{digits[len++]=static_cast<char>(-(x%10)+48);x/=10;}while(x);
		}
		while(len>0)putc(digits[--len]);
	}
}
// Complex number type used by the FFT routines.
using complex=std::complex<double>;
// Capacity of the rev / Wn lookup tables.
const int N=4e6+10;
const double PI=acos(-1);
// Imaginary unit.
const complex I=complex(0,1);
// rev: bit-reversal permutation, Wn: twiddle factors; both are filled by init().
int rev[N];
complex Wn[N];
// M: base used to split residues into two digits; mod: the modulus. Both are set in main().
int M,mod;
// Residue modulo the runtime global `mod`.
// Every operator relies on x being in [0,mod). The int constructor and
// operator=(int) store their argument verbatim, so values handed to them must
// already be reduced. Int operands of the arithmetic operators are reduced
// here, so negative or over-large ints no longer leave x out of range.
struct modint{
    int x;
    modint(int o=0){x=o;}
    modint &operator = (int o){return x=o,*this;}
    // Maps any int to its representative in [0,mod).
    static int reduce(int o){o%=mod;return o<0?o+mod:o;}
    modint &operator +=(modint o){return x=x+o.x>=mod?x+o.x-mod:x+o.x,*this;}
    modint &operator -=(modint o){return x=x-o.x<0?x-o.x+mod:x-o.x,*this;}
    modint &operator *=(modint o){return x=1ll*x*o.x%mod,*this;}
    // In-place exponentiation by squaring. NOTE(review): b is assumed to be
    // non-negative; a negative b would never reach 0 under >>=.
    modint &operator ^=(int b){
        modint a=*this,c=1;
        for(;b;b>>=1,a*=a)if(b&1)c*=a;
        return x=c.x,*this;
    }
    // Division multiplies by o^(mod-2); this is the inverse only when mod is
    // prime and o is non-zero (assumed, not checked).
    modint &operator /=(modint o){return *this *=o^=mod-2;}
    // Int overloads: reduce the operand first, then reuse the modint overloads.
    modint &operator +=(int o){return *this+=modint(reduce(o));}
    modint &operator -=(int o){return *this-=modint(reduce(o));}
    modint &operator *=(int o){return *this*=modint(reduce(o));}
    modint &operator /=(int o){return *this/=modint(reduce(o));}
	template<class Rhs>friend modint operator +(modint a,Rhs b){return a+=b;}
    template<class Rhs>friend modint operator -(modint a,Rhs b){return a-=b;}
    template<class Rhs>friend modint operator *(modint a,Rhs b){return a*=b;}
    template<class Rhs>friend modint operator /(modint a,Rhs b){return a/=b;}
    friend modint operator ^(modint a,int b){return a^=b;}
    friend bool operator ==(modint a,int b){return a.x==b;}
    friend bool operator !=(modint a,int b){return a.x!=b;}
    // const-qualified so they also work on const values.
    bool operator ! () const {return !x;}
    modint operator - () const {return x?mod-x:0;}
	// Declared with the postfix signature, but increments in place and returns
	// the updated object (kept as-is for existing callers).
	modint &operator++(int){return *this+=1;}
};
// Rounds the real part of x to the nearest integer (halves away from zero)
// and returns that integer reduced with % mod.
inline long long num(complex x){
	const double re=x.real();
	if(re<0)return (long long)(re-0.5)%mod;
	return (long long)(re+0.5)%mod;
}
// Polynomial whose coefficients are stored as two base-M digits:
// coefficient i equals a0[i]*M + a1[i] (see set()). The digits live in complex
// vectors so they can be fed to the FFT directly.
struct poly{
	std::vector<complex>a0,a1;
	// Number of stored coefficients.
	int size(){
		return a0.size();
	}
	// Resizes both digit vectors to n entries.
	void resize(int n){
		a0.resize(n);
		a1.resize(n);
	}
	// Stores residue y at index x as high digit y/M and low digit y%M.
	void set(int x,modint y){
		const int value=y.x;
		a0[x]=value/M;
		a1[x]=value%M;
	}
	// Recombines the four digit cross-products that operator* leaves in the
	// real/imaginary parts of a0[x] and a1[x]:
	// M*M*Re(a0) + M*(Im(a0)+Re(a1)) + Im(a1), all reduced modulo mod.
	long long get(int x){
		const long long hh=num(a0[x].real());
		const long long hl=num(a0[x].imag());
		const long long lh=num(a1[x].real());
		const long long ll=num(a1[x].imag());
		return (M*M*hh%mod+M*(hl+lh)%mod+ll)%mod;
	}
	// Reads back the coefficient written by set(): a0*M + a1, reduced modulo mod.
	modint val(int x){
		const long long packed=(long long)(M*a0[x].real()+a1[x].real()+mod);
		return packed%mod;
	}
};
// Coefficient-wise sum; the shorter operand is zero-padded to the longer length.
poly operator+(poly a,poly b){
    const int len=std::max(a.size(),b.size());
    a.resize(len);
    b.resize(len);
    for(int idx=0;idx<len;++idx){
        a.set(idx,a.val(idx)+b.val(idx));
    }
    return a;
}
// Coefficient-wise difference a-b; the shorter operand is zero-padded to the longer length.
poly operator-(poly a,poly b){
    const int len=std::max(a.size(),b.size());
    a.resize(len);
    b.resize(len);
    for(int idx=0;idx<len;++idx){
        a.set(idx,a.val(idx)-b.val(idx));
    }
    return a;
}
// Returns the constant polynomial 1 (a single coefficient).
inline poly one(){
	poly unit;
	unit.resize(1);
	unit.set(0,1);
	return unit;
}
// Returns the smallest k >= 0 with (1<<k) >= n.
inline int ext(int n){
	int bits=0;
	for(;(1<<bits)<n;++bits){}
	return bits;
}
// Prepares the lookup tables for a transform of length n = 1<<k:
//   rev[i] = i with its k low bits reversed,
//   Wn[i]  = cos(PI*i/n) + i*sin(PI*i/n).
inline void init(int k){
	const int n=1<<k;
	// Value contributed by the lowest bit of i once reversed. Computed as n>>1
	// rather than 1<<(k-1) so that k==0 does not shift by a negative amount.
	const int top=n>>1;
	for(int i=0;i<n;i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)?top:0);
	for(int i=0;i<n;i++)
		Wn[i]={cos(PI/n*i),sin(PI/n*i)};
}
void FFT(std::vector<complex>&A,int n,int t){
	if(t<0)for(int i=1;i<n;i++)if(i<(n-i))std::swap(A[i],A[n-i]);
	for(int i=0;i<n;i++)
		if(i<rev[i])std::swap(A[i],A[rev[i]]);
	for(int m=1;m<n;m<<=1)
		for(int i=0;i<n;i+=m<<1)
			for(int k=i;k<i+m;k++){
				complex W=Wn[1ll*(k-i)*n/m];
				complex a0=A[k],a1=A[k+m]*W;
				A[k]=a0+a1;A[k+m]=a0-a1; 
			}
	if(t<0)for(int i=0;i<n;i++)A[i]/=n;
}
// Transforms both digit vectors of A with a single FFT call: a0 + I*a1 is
// packed into A.a0 and transformed, then the two spectra are separated using
// the conjugate of the mirrored entry. This relies on a0 and a1 holding real
// values on entry, which is what poly::set stores.
// n is the transform length and t is forwarded to FFT unchanged.
void MTT(poly &A,int n,int t){
	for(int i=0;i<n;i++){
		A.a0[i]=A.a0[i]+I*A.a1[i];
	}
	FFT(A.a0,n,t);
	for(int i=0;i<n;i++){
		const int mirror=i?n-i:0;
		A.a1[i]=std::conj(A.a0[mirror]);
	}
	for(int i=0;i<n;i++){
		const complex packed=A.a0[i];
		const complex mirrored=A.a1[i];
		A.a0[i]=(packed+mirrored)*0.5;
		A.a1[i]=(mirrored-packed)*0.5*I;
	}
}
// Product of two polynomials modulo mod; the result has
// a.size()+b.size()-1 coefficients.
// Each operand is split into base-M digits (a0 high, a1 low). After MTT the
// four digit products are packed into two inverse transforms:
//   a.a0 = a0*b0 + I*(a1*b0),  a.a1 = a0*b1 + I*(a1*b1),
// and poly::get recombines them into the residue.
inline poly operator*(poly a,poly b){
    const int n=a.size()+b.size()-1;
    // n<=0 means there is nothing to compute (both operands empty, or one
    // empty and the other a single coefficient). Returning early avoids
    // resize(-1) and a length-1 transform whose result would be discarded.
    if(n<=0)return poly();
    const int k=ext(n);
    const int len=1<<k;
    a.resize(len),b.resize(len),init(k);
    MTT(a,len,1);MTT(b,len,1);
	for(int i=0;i<len;i++){
		const complex p=a.a0[i]*b.a0[i]+I*a.a1[i]*b.a0[i];
		const complex q=a.a0[i]*b.a1[i]+I*a.a1[i]*b.a1[i];
		a.a0[i]=p,a.a1[i]=q;
	}
    FFT(a.a0,len,-1);FFT(a.a1,len,-1);a.resize(n);
	// Round each coefficient back to an exact residue and re-split it into digits.
	for(int i=0;i<n;i++){
		const long long coef=a.get(i);
		a.set(i,coef);
	}
	return a;
}
// Formal derivative: coefficient i of the result is (i+1)*a[i+1], and the
// result has one coefficient fewer than the input.
inline poly deriv(poly a){
    // An empty polynomial is returned unchanged; otherwise a.size()-1 would be
    // -1 and resize(-1) would request an enormous vector.
    if(a.size()==0)return a;
    const int n=a.size()-1;
    for(int i=0;i<n;i++)a.set(i,a.val(i+1)*(i+1));
    a.resize(n);
    return a;
}
// Formal antiderivative with constant term 0: coefficient i of the result is
// a[i-1]/i, and the result has one coefficient more than the input.
inline poly inter(poly a){
    const int n=a.size()+1;
    a.resize(n);
    // Walk downwards so a[i-1] is read before it is overwritten. The highest
    // valid index is n-1; starting at n would write past the end of the vectors.
    for(int i=n-1;i>=1;i--)a.set(i,a.val(i-1)/i);
    a.set(0,0);
    return a;
}
// Multiplicative inverse of F truncated to 1<<k coefficients, by doubling:
// given H = inverse for the lower half of the length, returns 2H - H*H*F
// truncated to the full length. The base case inverts the constant term.
poly inv(poly F,int k){
    const int len=1<<k;
    F.resize(len);
    if(len==1){
        F.set(0,modint(1)/F.val(0));
        return F;
    }
    poly twice;
    poly prev=inv(F,k-1);
    twice.resize(len);
    prev.resize(len<<1);
    F.resize(len<<1);
    for(int i=0;i<len/2;i++){
        twice.set(i,prev.val(i)*2);
    }
    prev=prev*prev;
    prev.resize(len);
    prev=prev*F;
    prev.resize(len);
    twice=twice-prev;
    return twice;
}
// Inverse of a with the same number of coefficients as a: runs the doubling
// routine at the next power-of-two length and truncates the result.
inline poly inv(poly a){
    const int len=a.size();
    a=inv(a,ext(len));
    a.resize(len);
    return a;
}
// Logarithm of a, computed as the antiderivative of a'/a and truncated to
// a.size() coefficients.
inline poly ln(poly a){
    const int len=a.size();
    poly slope=deriv(a);
    poly recip=inv(a);
    a=inter(slope*recip);
    a.resize(len);
    return a;
}
// Exponential of a at length 1<<k by doubling: with f0 the result for half
// the length, returns f0*(1 + a - ln(f0)). The product is returned untruncated;
// callers resize it to the length they need.
poly exp(poly a,int k){
	const int len=1<<k;
	a.resize(len);
	if(len==1){
		return one();
	}
	poly half=exp(a,k-1);
	half.resize(len);
	poly correction=one()+a-ln(half);
	return half*correction;
}
// Exponential of a with the same number of coefficients as a: runs the
// doubling routine at the next power-of-two length and truncates the result.
poly exp(poly a){
	const int len=a.size();
	a=exp(a,ext(len));
	a.resize(len);
	return a;
}
// Sample values used by get() for interpolation.
modint y[1000];
// Prefix / suffix product tables used by get(). They point into aaa (set up
// in main) so that index -1 is addressable.
modint *pre;
modint *suf;
modint aaa[2000];
// Factorials and inverse factorials, filled in main().
modint fac[1000];
modint ifac[1000];
// Returns sum_{i=1}^{k} i^n. The sum is a polynomial of degree n+1 in k, so it
// is evaluated by Lagrange interpolation through the n+2 points 0..n+1.
modint get(int n,int k){
	// y[i] = 1^n + 2^n + ... + i^n for i = 0..n+1.
	modint running=0;
	for(int i=0;i<=n+1;i++){
		y[i]=running;
		running+=(modint(i+1)^n);
	}
	const int deg=n+1;
	// pre[i] = prod_{j<=i}(k-j), suf[i] = prod_{j>=i}(k-j), with empty products equal to 1.
	pre[-1]=suf[deg+1]=1;
	for(int i=0;i<=deg;i++){
		pre[i]=pre[i-1]*(k-i);
	}
	for(int i=deg;i>=0;i--){
		suf[i]=suf[i+1]*(k-i);
	}
	// Term i has denominator i! * (deg-i)! * (-1)^(deg-i).
	modint total=0;
	for(int i=0;i<=deg;i++){
		const modint term=y[i]*pre[i-1]*suf[i+1]*ifac[i]*ifac[deg-i];
		if((deg-i)&1){
			total-=term;
		}else{
			total+=term;
		}
	}
	return total;
}
int n,k;
signed main(){
	in::read(k,n,mod);M=sqrt(mod);
	pre=&aaa[10];suf=&aaa[1010];ifac[0]=fac[0]=1;for(int i=1;i<1000;i++)fac[i]=fac[i-1]*i,ifac[i]=modint(1)/fac[i];
	poly F;F.resize(n+1);
	for(int i=1;i<=n;i++)
		if((i+1)&1)F.set(i,-get(i,k)/i);
		else F.set(i,get(i,k)/i);
	//for(int i=0;i<=n;i++)printf("%d ",F.val(i).x);puts("");
	F=exp(F);out::write((F.val(n)*fac[n]).x);
	out::flush();
	return 0;
}
posted @ 2021-01-23 22:06  caigou233  阅读(68)  评论(0编辑  收藏  举报