bzoj 4555 NTT优化子集斯特林
题目大意
读入n
求\(f(n)=\sum_{i=0}^n\sum_{j=0}^i\left\{\begin{matrix}i \\ j\end{matrix}\right\}*2^j*j!\)
分析
\(f(n)=\sum_{i=0}^n\sum_{j=0}^i\left\{\begin{matrix}i \\ j\end{matrix}\right\}*2^j*j!\)
因为斯特林三角中\(j>i\)时值为0,j枚举上界可以改为n
\(f(n)=\sum_{i=0}^n\sum_{j=0}^n\left\{\begin{matrix}i \\ j\end{matrix}\right\}*2^j*j!\)
改下求和顺序
\(f(n)=\sum_{j=0}^n2^j*j!\sum_{i=0}^n\left\{\begin{matrix}i \\ j\end{matrix}\right\}\)
关于斯特林三角形总和公式的推导见我上一篇博客
solution
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <cmath>
#include <algorithm>
using namespace std;
typedef long long LL;
const LL Q=998244353;
const int N=262144;
const int M=262145;
inline int rd(){
int x=0;bool f=1;char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=0;
for(;isdigit(c);c=getchar()) x=x*10+c-48;
return f?x:-x;
}
int n;
int rev[N];
LL g;
LL fac[M];
LL ifac[M];
LL inv[M];
LL a[N];
LL b[N];
LL c[N];
LL pwr(LL x,LL tms,LL mod){
LL res=1;
for(;tms>0;tms>>=1){
if(tms&1) res=res*x%mod;
x=x*x%mod;
}
return res;
}
void NTT(LL *a,int fl){
int i,j,k;
LL Wn,W,u,v;
for(i=0;i<N;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(i=2;i<=N;i<<=1){
if(fl==1) Wn=pwr(g,(Q-1)/i,Q);
else Wn=pwr(inv[g],(Q-1)/i,Q);
for(j=0;j<N;j+=i){
for(W=1,k=j;k<j+i/2;k++,W=W*Wn%Q){
u=a[k];
v=a[k+i/2]*W%Q;
a[k]=(u+v)%Q;
a[k+i/2]=((u-v)%Q+Q)%Q;
}
}
}
if(fl==-1)
for(i=0;i<N;i++) a[i]=a[i]*inv[N]%Q;
}
bool judge(LL x,LL mm){
for(int i=2;i*i<=mm;i++)
if((mm-1)%i==0&&pwr(x,(mm-1)/i,mm)==1) return 0;
return 1;
}
LL getrt(LL mm){
if(mm==2)return 1;
for(int i=2;;i++)
if(judge(i,mm)) return i;
}
int main(){
int i,kd;
n=rd();
for(i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(N>>1):0);
for(inv[1]=1,i=2;i<M;i++) inv[i]=(Q-Q/i)*inv[Q%i]%Q;
for(fac[0]=1,i=1;i<M;i++) fac[i]=fac[i-1]*i%Q;
for(ifac[0]=1,i=1;i<M;i++) ifac[i]=ifac[i-1]*inv[i]%Q;
for(i=0;i<=n;i++){
kd=(i&1)?-1:1;
a[i]=((kd*ifac[i])%Q+Q)%Q;
}
b[0]=1;b[1]=n+1;
for(i=2;i<=n;i++){
b[i]=((pwr(i,n+1,Q)-1)%Q+Q)%Q*inv[i-1]%Q*ifac[i]%Q;
}
g=getrt(Q);
NTT(a,1);
NTT(b,1);
for(i=0;i<N;i++) c[i]=a[i]*b[i]%Q;
NTT(c,-1);
LL ans=0;
for(i=0;i<=n;i++)
ans=(ans+(pwr(2,i,Q)*fac[i]%Q*c[i]%Q))%Q;
printf("%lld\n",ans);
return 0;
}