「HEOI 2016/TJOI 2016」求和
题目链接
\(Solution\)
先化简式子:
\[f(n)=\sum_{i=0}^n\sum_{j=0}^i\begin{Bmatrix} i \\ j \end {Bmatrix}*2^j*j!
\]
\[f(n)=\sum_{j=0}^n2^j*j!\sum_{i=0}^n\begin{Bmatrix} i \\ j \end {Bmatrix}
\]
根据第二类斯特林数的公式:
\[f(n)=\sum_{j=0}^n2^j*j!\sum_{i=0}^n\sum_{k=0}^j\frac{(-1)^k}{k!}*\frac{(j-k)^i}{(j-k)!}
\]
\[f(n)=\sum_{j=0}^n2^j*j!\sum_{k=0}^j\frac{(-1)^k}{k!}*\frac{\sum_{i=0}^n(j-k)^i}{(j-k)!}
\]
我们令\(F(i)=\frac{(-1)^k}{i!},G(i)=\frac{\sum_{j=0}^ni^j}{i!}\)
根据等比数列求和公式可得:
\(\sum_{j=0}^ni^j=\frac{i^{n+1}-1}{i-1}\)
所以\(G(i)=\frac{i^{n+1}-1}{(i-1)*i!}\)
\(G(0)=1,G(1)=n+1\)
那么
\[f(n)=\sum_{j=0}^n2^j*j!\sum_{k=0}^jF(k)*G(j-k)
\]
因为当\(j>k\)时\(G(j-k)=0\)
所以原式等价于:
\[f(n)=\sum_{j=0}^n2^j*j!\sum_{k=0}^nF(k)*G(j-k)
\]
这个东西直接\(NTT\)搞一搞就好了
\(Code\)
#include<bits/stdc++.h>
#define int long long
#define rg register
#define file(x) freopen(x".in","r",stdin);freopen(x".out","w",stdout);
using namespace std;
const int mod=998244353;
const int N=500010;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9') f=(c=='-')?-1:1,c=getchar();
while(c>='0'&&c<='9') x=x*10+c-48,c=getchar();
return f*x;
}
int f[N],g[N],r[N],limit=1,w[N],inv[N],a[N],b[N],jc[N];
int ksm(int a,int b){
int ans=1;
while(b){
if(b&1) ans=ans*a%mod;
a=a*a%mod,b>>=1;
}
return ans;
}
void ntt(int *a,int opt){
for(int i=0;i<=limit;i++)
if(i<r[i])
swap(a[i],a[r[i]]);
for(int i=1;i<limit;i<<=1){
int w=ksm(3,(mod-1)/(i*2));
if(opt==-1) w=ksm(w,mod-2);
for(int j=0;j<limit;j+=i<<1){
int l=1;
for(int k=j;k<j+i;k++){
int p=l*a[k+i]%mod;
a[k+i]=(a[k]-p+mod)%mod;
a[k]=(a[k]+p)%mod;
l=l*w%mod;
}
}
}
}
main(){
int n=read(),l=0,ans=0;
while(limit<=(n<<1))
limit<<=1,l++;
jc[0]=1;
for(int i=1;i<=limit;i++)
jc[i]=jc[i-1]*i%mod;
inv[limit]=ksm(jc[limit],mod-2);
for(int i=limit-1;i>=0;i--)
inv[i]=inv[i+1]*(i+1)%mod;
g[0]=1,g[1]=n+1,f[1]=mod-1,f[0]=1;
for(int i=2;i<=n;i++)
g[i]=(ksm(i,n+1)-1)%mod*inv[i]%mod*ksm(i-1,mod-2)%mod,f[i]=i&1?mod-inv[i]:inv[i];
for(int i=0;i<limit;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
ntt(f,1),ntt(g,1);
for(int i=0;i<limit;i++)
f[i]=f[i]*g[i]%mod;
ntt(f,-1);
int inv=ksm(limit,mod-2);
for(int i=0;i<=n;i++)
ans=(ans+ksm(2,i)*jc[i]%mod*f[i]%mod*inv%mod)%mod;
printf("%lld",ans);
}