bzoj 4555: [Tjoi2016&Heoi2016]求和【NTT】

暴力推式子推诚卷积形式,但是看好多blog说多项式求逆不知道是啥..

\[\sum_{i=0}^{n}\sum_{j=0}^{n}S(i,j)*2^j*j! \]

\[S(i,j)=\frac{1}{j!}\sum_{k=0}^{j}(-1)^k*C_j^k*(j-k)^i \]

\[S(i,j)=\frac{1}{j!}\sum_{k=0}^{j}(-1)^k*\frac{j!}{k!(j-k)!}*(j-k)^i \]

\[\sum_{i=0}^{n}\sum_{j=0}^{n}\frac{1}{j!}\sum_{k=0}^{j}(-1)^k*\frac{j!}{k!(j-k)!}*(j-k)^i*2^j*j! \]

\[\sum_{i=0}^{n}\sum_{j=0}^{n}\sum_{k=0}^{j}(-1)^k*\frac{j!}{k!(j-k)!}*(j-k)^i*2^j \]

\[\sum_{j=0}^{n}2^j*j!\sum_{k=0}^{j}(-1)^k*\frac{1}{k!(j-k)!}*\sum_{i=0}^{n}(j-k)^i \]

\[\sum_{j=0}^{n}2^j*j!\sum_{k=0}^{j}\frac{(-1)^k}{k!}*\frac{\sum_{i=0}^{n}(j-k)^i}{(j-k)!} \]

\[a[k]=\frac{(-1)^k}{k!},b[k]=\frac{\sum_{i=0}^{n}k^i}{k!} \]

\[\sum_{j=0}^{n}2^j*j!\sum_{k=0}^{j}a[k]*b[j-k] \]

于是就得到了卷积形式,可以上NTT了
顺便根据等比数列求和公式,\(\sum_{i=0}{i}kn=\frac{k^{n+1}-1}{k-1} \)

#include<iostream>
#include<cstdio>
using namespace std;
const int N=300005,mod=998244353,G=3;
int n,fac[N],inv[N],fi[N],a[N],b[N],re[N],lm,bt,ans;
int ksm(int a,int b)
{
	int r=1;
	while(b)
	{
		if(b&1)
			r=1ll*r*a%mod;
		a=1ll*a*a%mod;
		b>>=1;
	}
	return r;
}
void dft(int a[],int f)
{
	for(int i=0;i<lm;i++)
		if(i<re[i])
			swap(a[i],a[re[i]]);
	for(int i=1;i<lm;i<<=1)
	{
		int wi=ksm(G,(mod-1)/(i<<1));
		if(f==-1)
			wi=ksm(wi,mod-2);
		for(int k=0;k<lm;k+=(i<<1))
		{
			int w=1,x,y;
			for(int j=0;j<i;j++)
			{
				x=a[j+k];
				y=1ll*w*a[i+j+k]%mod;
				a[j+k]=((x+y)%mod+mod)%mod;
				a[i+j+k]=((x-y)%mod+mod)%mod;
				w=1ll*w*wi%mod;
			}
		}
	}
	if(f==-1)
	{
		int ni=ksm(lm,mod-2);
		for(int i=0;i<lm;i++)
			a[i]=1ll*a[i]*ni%mod;
	}
}
void ntt()
{
	bt=1;
	for(;(1<<bt)<=2*n;bt++);
	lm=(1<<bt);
	for(int i=0;i<=lm;i++)
		re[i]=(re[i>>1]>>1)|((i&1)<<(bt-1));
	dft(a,1);
	dft(b,1);
	for(int i=0;i<lm;i++)
		a[i]=1ll*a[i]*b[i]%mod;
	dft(a,-1);
}
int main()
{
	scanf("%d",&n);
	inv[1]=1,fac[0]=fi[0]=1;
	for(int i=1;i<=n;i++)
	{
		if(i!=1)
			inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
		fac[i]=1ll*fac[i-1]*i%mod;
		fi[i]=fi[i-1]*inv[i]%mod;
	}
	a[0]=1;
	for(int i=1;i<=n;i++)
		a[i]=((i&1)?-1:1)*fi[i];
	b[0]=1,b[1]=n+1;
	for(int i=2;i<=n;i++)
		b[i]=1ll*(ksm(i,n+1)-1)*inv[i-1]%mod*fi[i]%mod;
	ntt();
	for(int i=0;i<=n;i++)
		ans=(ans+1ll*fac[i]*ksm(2,i)%mod*a[i]%mod)%mod;
	printf("%d",(ans%mod+mod)%mod);
	return 0;
}
posted @ 2018-02-26 14:37  lokiii  阅读(160)  评论(0编辑  收藏  举报