BZOJ4555: [Tjoi2016&Heoi2016]求和
把题目式子结合第二类斯特林数通项公式化一化就变成NTT模板题了:
\[
\begin{aligned}
f(n)&=\sum_{i=0}^n\sum_{j=0}^iS(i,j)*2^j*j!\\
&=\sum_{i=0}^n\sum_{j=0}^nS(i,j)*2^j*j!\qquad(S(i,j)=0,\ j>i)\\
&=\sum_{i=0}^n\sum_{j=0}^n2^j\sum_{k=0}^j(-1)^k\binom{j}{k}(j-k)^i\\
&=\sum_{i=0}^n\sum_{j=0}^n2^j\sum_{k=0}^j(-1)^k\frac{j!}{k!(j-k)!}(j-k)^i\\
&=\sum_{j=0}^n2^j*j!\sum_{k=0}^{j}\frac{(-1)^k}{k!}*\frac{\sum_{i=0}^{n}(j-k)^i}{(j-k)!}
\end{aligned}
\]
\[g(x)=\frac{(-1)^x}{x!},h(x)=\frac{\sum_{i=0}^nx^i}{x!}
\]
\[f(n)=\sum_{j=0}^n2^j*j!\sum_{k=0}^{j}g(k)*h(j-k)
\]
\[c(j)=\sum_{k=0}^{j}g(k)*h(j-k)
\]
\[f(n)=\sum_{j=0}^n2^j*j!*c(j)
\]
#include <cstdio>
#include <algorithm>
using namespace std;
// Modulus under which every value in this program is reduced.
const int pps = 998244353;
// Base that NTT() raises to a power to obtain the per-stage root of unity.
const int g = 3;
// Capacity of every work array below.
const int maxn = 262145;

// Input parameter n (upper bound of the sums).
int n;
// Transform length: starts at 1 and is doubled by main() until it exceeds 2n.
int m = 1;
// Number of doublings applied to m, i.e. log2(m).
int lg;

// G[x] = (-1)^x / x!  (first convolution operand).
int G[maxn];
// H[x] = (sum_{i=0..n} x^i) / x!  (second convolution operand).
int H[maxn];
// C = G * H, filled by main().
int C[maxn];

// bin[i] = 2^i mod pps.
int bin[maxn];
// fac[i] = i! mod pps.
int fac[maxn];
// inv[i] = modular inverse of fac[i].
int inv[maxn];
// rev[i] = bit-reversed index of i used by NTT().
int rev[maxn];
/**
 * Reads one (optionally negative) decimal integer from stdin.
 *
 * Skips every character that is not a digit; a '-' seen while skipping makes
 * the result negative. Then consumes the run of digits that follows.
 *
 * The character is kept in an int so that EOF stays distinguishable from real
 * input: on end-of-input the skip loop stops instead of spinning forever.
 *
 * @return the parsed value, or 0 if end-of-input is reached before any digit.
 */
int read() {
    int value = 0;
    int sign = 1;
    int ch = getchar();
    // Skip non-digits, remembering a minus sign; stop at end-of-input.
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    // Accumulate the digit run (EOF is not a digit, so this also stops there).
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param a base
 * @param b exponent (callers in this file pass non-negative values)
 * @return a^b reduced modulo pps
 */
int quick(int a,int b) {
    long long power = a;   // a^(2^k) for the bit currently being examined
    long long acc = 1;     // product of the powers selected by the set bits of b
    for (; b; b >>= 1) {
        if (b & 1) acc = power * acc % pps;
        power = power * power % pps;
    }
    return static_cast<int>(acc);
}
/**
 * Fills the lookup tables and both convolution operands for the global n.
 *
 * After the call, for 0 <= i <= n:
 *   bin[i] = 2^i,  fac[i] = i!,  inv[i] = (i!)^{-1}      (all modulo pps)
 *   G[i]   = (-1)^i / i!
 *   H[i]   = (sum_{k=0..n} i^k) / i!
 *
 * For n <= 0 only index 0 is initialised. The previous countdown loop used
 * `i` as its condition, so with n == 0 it started at -1 and wrote before the
 * start of G/inv without ever terminating.
 */
void prepare() {
    bin[0]=fac[0]=H[0]=G[0]=inv[0]=1;
    if(n<=0)return;
    // x = 1: the geometric sum 1 + 1 + ... + 1 has n + 1 terms, and 1! = 1.
    H[1]=n+1;
    for(int i=1;i<=n;i++)
        bin[i]=1ll*bin[i-1]*2%pps;
    for(int i=1;i<=n;i++)
        fac[i]=1ll*fac[i-1]*i%pps;
    // One modular inverse via exponent pps-2, then walk down:
    // 1/(i!) = (i+1) / ((i+1)!).
    G[n]=inv[n]=quick(fac[n],pps-2);
    for(int i=n-1;i>0;i--)
        G[i]=inv[i]=1ll*inv[i+1]*(i+1)%pps;
    // Apply the (-1)^i sign: negate the odd-index entries modulo pps.
    for(int i=1;i<=n;i+=2)G[i]=pps-G[i];
    // x >= 2: closed form (x^{n+1} - 1) / (x - 1) of the geometric sum, over x!.
    for(int i=2;i<=n;i++)
        H[i]=1ll*inv[i]*(quick(i,n+1)-1)%pps*quick(i-1,pps-2)%pps;
}
void NTT(int *a,int sign) {
for(int i=0;i<m;i++)
if(rev[i]>i)swap(a[i],a[rev[i]]);
for(int s=2;s<=m;s<<=1) {
int gn=quick(g,((pps-1)/s*sign+pps-1)%(pps-1));
for(int i=0;i<m;i+=s) {
int w=1;
for(int j=0;j<(s>>1);j++,w=1ll*w*gn%pps) {
int x=a[i+j]%pps,y=1ll*w*a[i+(s>>1)+j]%pps;
a[i+j]=(x+y)%pps,a[i+(s>>1)+j]=(x-y+pps)%pps;
}
}
}
if(sign==1)return;
int invm=quick(m,pps-2);
for(int i=0;i<m;i++)
a[i]=1ll*a[i]*invm%pps;
}
int main() {
n=read();
prepare();
for(;m<=n*2;m<<=1,lg++);
for(int i=0;i<m;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
NTT(G,1);NTT(H,1);
for(int i=0;i<m;i++)
C[i]=1ll*G[i]*H[i]%pps;
NTT(C,-1);int ans=1;
for(int i=1;i<=n;i++)
ans=(ans+1ll*bin[i]*fac[i]%pps*C[i]%pps)%pps;
printf("%d\n",ans);
return 0;
}