[TJOI2016&&HEOI2016]求和
\[f(n)=\sum_{i=0}^{n}\sum_{j=0}^{i}S(i,j)*2^j*j!
\]
其中\(S(i,j)\)是第二类斯特林数
\(n\le10^5\),模\(998244353\)
sol
所以说后面两项到底是干什么的
把\(j\)提到前面去
\[f(n)=\sum_{j=0}^{n}2^j*j!\sum_{i=0}^{n}S(i,j)
\]
(\(i\)从\(0\)开始是没有问题的,因为当\(i<j\)的时候\(S(i,j)=0\))
我们知道
\[S(i,j)=\frac{1}{j!}\sum_{k=0}^{j}(-1)^k\binom{j}{k}(j-k)^i
\]
那么
\[\sum_{i=0}^{n}S(i,j)=\sum_{i=0}^{n}\frac{1}{j!}\sum_{k=0}^{j}(-1)^k\binom{j}{k}(j-k)^i\\=\frac{1}{j!}\sum_{k=0}^{j}(-1)^k\binom{j}{k}\sum_{i=0}^{n}(j-k)^i
\]
发现后面的其实就是一个等比数列求和
直接用公式
\[\sum_{i=0}^{n}q^i=\frac{q^{n+1}-1}{q-1}
\]
注意特判\(q=0\)和\(q=1\):公式分母为\(q-1\),\(q=1\)时不能直接套用,此时和为\(n+1\);\(q=0\)时和为\(1\)(注意\(0^0=1\))
然后直接上卷积啊
code
#include<cstdio>
#include<algorithm>
using namespace std;
// Capacity of every global work array (jc, inv, a, b, rev). It has to cover
// the NTT length N, which main() picks as a power of two greater than 2n.
const int _ = 400005;
// Modulus for all arithmetic. NTT() uses 3 as a primitive root of this prime
// (998244353 = 119 * 2^23 + 1).
const int mod = 998244353;
// Reads the next decimal integer from stdin.
//
// Skips every character up to the first digit or '-', then parses an optional
// single '-' followed by a run of digits. The first character after the digit
// run is consumed as well.
//
// Returns the parsed value, or 0 when the input ends before any digit is seen
// (previously that case looped forever because EOF was never tested).
int gi()
{
	int x=0,w=1;
	// getchar() returns int; keeping it as int makes EOF distinguishable
	// from every real character.
	int ch=getchar();
	while (ch!=EOF&&(ch<'0'||ch>'9')&&ch!='-') ch=getchar();
	if (ch=='-') w=0,ch=getchar();
	while (ch>='0'&&ch<='9')
	{
		x=x*10+(ch-'0');
		ch=getchar();
	}
	return w?x:-x;
}
// Binary (square-and-multiply) exponentiation: returns a^b reduced modulo the
// file-level constant `mod`. An exponent of 0 yields 1.
// The loop only terminates for b >= 0.
int fastpow(int a,int b)
{
	long long acc=1;   // running product
	long long base=a;  // a^(2^t) for the bit currently examined
	for (;b;b>>=1)
	{
		if (b&1) acc=acc*base%mod;
		base=base*base%mod;
	}
	return static_cast<int>(acc);
}
// Shared state of the solution (all zero-initialised as globals):
//   n    - the input value read in main()
//   N    - transform length used by NTT(), a power of two chosen in main()
//   jc   - factorials modulo mod
//   inv  - inverse factorials modulo mod
//   a, b - coefficient arrays of the two polynomials that get multiplied;
//          NTT() transforms them in place
//   rev  - bit-reversal permutation of 0..N-1 consumed by NTT()
//   l    - bit counter computed in main() while choosing N
//   ans  - accumulated answer, printed by main()
int n,N,jc[_],inv[_],a[_],b[_],rev[_],l,ans;
// In-place iterative number-theoretic transform of P[0..N-1] modulo `mod`.
//
// P   - coefficient array; overwritten with the transform.
// opt - -1 selects the inverse transform (inverted roots plus a final
//       multiplication by N^-1); any other value runs the forward transform.
//
// Reads the globals N (transform length) and rev (bit-reversal table), which
// the caller must have filled in beforehand.
void NTT(int *P,int opt)
{
	// Reorder the input into bit-reversed order; each pair is swapped once.
	for (int idx=0;idx<N;++idx)
	{
		const int mirror=rev[idx];
		if (idx>mirror) std::swap(P[idx],P[mirror]);
	}

	// Butterfly passes: merge blocks of size `half` into blocks of size `span`.
	for (int half=1;half<N;half<<=1)
	{
		const int span=half<<1;
		// Primitive span-th root of unity (3 is the generator used here);
		// the inverse transform uses its modular inverse.
		int root=fastpow(3,(mod-1)/span);
		if (opt==-1) root=fastpow(root,mod-2);

		for (int base=0;base<N;base+=span)
		{
			int tw=1;  // current twiddle factor root^k
			for (int k=0;k<half;++k)
			{
				const int lo=base+k;
				const int hi=lo+half;
				const int u=P[lo];
				const int v=1ll*P[hi]*tw%mod;
				P[lo]=(u+v)%mod;
				P[hi]=(u-v+mod)%mod;
				tw=1ll*tw*root%mod;
			}
		}
	}

	if (opt!=-1) return;

	// Inverse transform: scale every entry by N^-1.
	const int scale=fastpow(N,mod-2);
	for (int idx=0;idx<N;++idx) P[idx]=1ll*P[idx]*scale%mod;
}
// Reads n and prints
//   f(n) = sum_{j=0..n} 2^j * j! * sum_{i=0..n} S(i,j)   (mod 998244353),
// where S is the Stirling number of the second kind.
//
// Using S(i,j) = 1/j! * sum_k (-1)^k C(j,k) (j-k)^i, the inner sum becomes
// the convolution c = a * b with
//   a[k] = (-1)^k / k!
//   b[m] = (sum_{i=0..n} m^i) / m!
// so the answer is sum_j 2^j * j! * c[j]. The convolution is done with NTT.
int main()
{
	n=gi();

	// Factorials jc[i] = i! and inverse factorials inv[i] = 1/i!.
	jc[0]=inv[0]=1;
	for (int i=1;i<=n;++i) jc[i]=1ll*jc[i-1]*i%mod;
	inv[n]=fastpow(jc[n],mod-2);
	for (int i=n-1;i;--i) inv[i]=1ll*inv[i+1]*(i+1)%mod;

	// a[k] = (-1)^k / k!
	for (int i=0;i<=n;++i) a[i]=i&1?mod-inv[i]:inv[i];

	// b[m] = (geometric sum of m^0..m^n) / m!.
	// m = 0 and m = 1 cannot use the closed form: 0^0 = 1 gives a sum of 1,
	// and ratio 1 gives a sum of n + 1 (1/1! = 1, so no extra factor).
	b[0]=1;b[1]=n+1;
	for (int i=2;i<=n;++i)
	{
		// (i^(n+1) - 1) / (i - 1)
		const int geo=1ll*(fastpow(i,n+1)-1+mod)%mod*fastpow(i-1,mod-2)%mod;
		b[i]=1ll*geo*inv[i]%mod;
	}

	// Transform length: smallest power of two greater than 2n, so the full
	// product (degree <= 2n) fits. Afterwards l = log2(N) - 1.
	for (N=1;N<=2*n;N<<=1) ++l;
	--l;
	// Highest bit of an index below N. For n == 0 we get N == 1 and l == -1;
	// shifting by a negative count is undefined behaviour, so guard it.
	const int top=l>=0?1<<l:0;
	for (int i=0;i<N;++i) rev[i]=(rev[i>>1]>>1)|((i&1)?top:0);

	// c = a * b via forward NTT, pointwise product, inverse NTT.
	NTT(a,1);NTT(b,1);
	for (int i=0;i<N;++i) a[i]=1ll*a[i]*b[i]%mod;
	NTT(a,-1);

	// ans = sum_j 2^j * j! * c[j]; pw2 tracks 2^j.
	for (int i=0,pw2=1;i<=n;++i,pw2=(pw2<<1)%mod)
		ans=(ans+1ll*a[i]*jc[i]%mod*pw2)%mod;

	printf("%d\n",ans);
	return 0;
}