【XSY3906】数数题(期望,多项式)
题面
题解
让多项式完全入门、求导积分0基础的来做这个???
赶紧看来一波高中选修2-2,才来硬推这道题。
所以说有一些对于巨佬们来说很简单就能推出来的东西,我可能反而会用一些更加复杂的方法。
题目可以看成是对于所有的 \(k\) 求出:
化简一下:
将 \(\left(\sum\limits_{i=1}^n b_i\right)^k\) 多项式展开:
把 \(\prod\limits_{i=1}^n b_i^{y_i}\) 拆开,并交换一下求和符号:(这一步可能跳得有点多,但还是不详细阐述,因为自己推一下就可以了)
将最前面的 \(\Delta x^n\) 扔进去:
仔细观察 \(\sum\limits_{b_i=0}^{\frac{a_i}{\Delta x}} (b_i\Delta x)^{y_i}\Delta x\),根据定积分的定义,发现它就是:
所以原式:
以上就是将题目的要求化成一个跟无限无关的东西的过程。
写到这里才发现我的化简过程很繁琐,应该一开始就直接多项式展开然后再把积分套进去,这样就会简略很多。
将 \(\dbinom{k}{y_1,y_2,\cdots,y_n}\) 化简:
代入原式得:
设 \(F(x)=\sum\limits_{n=0}^{\infty}\dfrac{x^n}{(n+1)!}\),则 \(\dfrac{a_i^{y_i}}{(y_i+1)!}=[x^{y_i}]F(a_ix)\),原式即为:
考虑对这个式子先取 \(\ln\) 再取 \(\exp\)(显然式子值不变),得:
设 \(G(x)=\ln F(x)\),原式即为:
也就是说我们需要求出 \(\sum\limits_{i=1}^nG(a_ix)\) 的前 \(m\) 项系数。
显然 \(G(x)\) 我们是可以直接求出来的,设 \(G(x)=\sum\limits_{j=0}^{\infty}g_jx^j\),那么 \(G(a_ix)\) 的第 \(j\) 项系数为 \(g_ja_i^j\),那么 \(\sum\limits_{i=1}^nG(a_ix)\) 的第 \(j\) 项系数为 \(g_j\sum\limits_{i=1}^na_i^j\)。
所以我们需要对所有的 \(j=1\sim m\),求出 \(\sum\limits_{i=1}^na_i^j\)。
这是一个很经典的问题,相当于要求出 \(\sum\limits_{i=1}^n\dfrac{1}{1-a_ix}\)。(原因是 \(\dfrac{1}{1-a_ix}\) 第 \(j\) 项的系数就是 \(a_i^j\),这个自己做大除法模拟多项式求逆验证即可)
至于如何求 \(\sum\limits_{i=1}^n \dfrac{1}{1-a_ix}=\dfrac{1}{1-a_1x}+\dfrac{1}{1-a_2x}+\cdots+\dfrac{1}{1-a_nx}\),直接用分治 NTT 解决。
具体来说,假设当前要求 \(\sum\limits_{i=l}^r \dfrac{1}{1-a_ix}\),已经知道了 \(\dfrac{A}{B}=\sum\limits_{i=l}^{mid} \dfrac{1}{1-a_ix}\),\(\dfrac{C}{D}=\sum\limits_{i=mid+1}^{r} \dfrac{1}{1-a_ix}\),那么用 NTT 求出 \(\dfrac{A}{B}+\dfrac{C}{D}=\dfrac{AD+BC}{BD}\) 即可。
那么就可以求出 \(\sum\limits_{i=1}^nG(a_ix)\) 的前 \(m\) 项系数,那么 \(k![x^k]\exp\left(\sum\limits_{i=1}^nG(a_ix)\right)\) 的前 \(m\) 项系数(即答案)也就出来了。
代码如下:
#include<bits/stdc++.h>
#define LN 19
#define N 100010
#define lc (k<<1)
#define rc (k<<1|1)
#define ll long long
#define mod 998244353
using namespace std;
int n,m;
int rev[N<<2];
ll fac[N],ifac[N];
ll w[LN][N<<2][2];
ll a[N<<2],f[N<<2],g[N<<2];
ll ff[N<<2],daof[N<<2],invf[N<<2];
ll daog[N<<2],lng[N<<2];
ll poww(ll a,ll b)
{
ll ans=1;
while(b)
{
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
void init(int limit)
{
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
int len=mid<<1;
ll wn=poww(3,(mod-1)/len);
ll iwn=poww(wn,mod-2);
ll noww=1,nowiw=1;
for(int j=0;j<mid;j++,noww=noww*wn%mod,nowiw=nowiw*iwn%mod)
w[bit][j][0]=noww,w[bit][j][1]=nowiw;
}
}
void NTT(ll *a,int limit,int opt)
{
opt=(opt<0);
for(int i=0;i<limit;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)*(limit>>1));
for(int i=0;i<limit;i++)
if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
for(int len=mid<<1,i=0;i<limit;i+=len)
{
for(int j=0;j<mid;j++)
{
ll x=a[i+j],y=w[bit][j][opt]*a[i+mid+j]%mod;
a[i+j]=(x+y)%mod,a[i+mid+j]=(x-y+mod)%mod;
}
}
}
if(opt)
{
ll div=poww(limit,mod-2);
for(int i=0;i<limit;i++)
a[i]=a[i]*div%mod;
}
}
void getinv(ll *f,ll *g,int n)
{
g[0]=poww(f[0],mod-2);
int now=2;
for(;now<(n<<1);now<<=1)
{
int limit=now<<1;
for(int i=0;i<now;i++) ff[i]=f[i];
NTT(ff,limit,1),NTT(g,limit,1);
for(int i=0;i<limit;i++)
g[i]=(2ll*g[i]%mod-ff[i]*g[i]%mod*g[i]%mod+mod)%mod;
NTT(g,limit,-1);
for(int i=now;i<limit;i++) g[i]=0;
}
for(int i=0;i<now;i++) ff[i]=0;
for(int i=n;i<now;i++) g[i]=0;
}
void getdao(ll *f,ll *g,int n)
{
g[n-1]=0;
for(int i=0;i<n-1;i++)
g[i]=(i+1)*f[i+1]%mod;
}
void getint(ll *f,ll *g,int n)
{
g[0]=0;
for(int i=1;i<n;i++)
g[i]=poww(i,mod-2)*f[i-1]%mod;
}
void getln(ll *f,ll *g,int n)
{
getdao(f,daof,n);
// printf("daof=");
// for(int i=0;i<n;i++)
// printf("%lld ",daof[i]);
// puts("");
getinv(f,invf,n);
// printf("invf=");
// for(int i=0;i<n;i++)
// printf("%lld ",invf[i]);
// puts("");
int limit=1;
while(limit<(n<<1)) limit<<=1;
NTT(daof,limit,1),NTT(invf,limit,1);
for(int i=0;i<limit;i++)
daog[i]=daof[i]*invf[i]%mod;
NTT(daog,limit,-1);
getint(daog,g,n);
for(int i=0;i<limit;i++)
daof[i]=invf[i]=daog[i]=0;
}
void getexp(ll *f,ll *g,int n)
{
g[0]=1;
int now=2;
for(;now<(n<<1);now<<=1)
{
int limit=now<<1;
getln(g,lng,now);
for(int i=0;i<now;i++) ff[i]=(f[i]-lng[i]+mod)%mod;
ff[0]=(ff[0]+1)%mod;
NTT(g,limit,1),NTT(ff,limit,1);
for(int i=0;i<limit;i++)
g[i]=g[i]*ff[i]%mod;
NTT(g,limit,-1);
for(int i=0;i<limit;i++) ff[i]=0;
}
for(int i=0;i<now;i++) lng[i]=0;
}
ll fzl[N<<2],fzr[N<<2],fml[N<<2],fmr[N<<2],A[N<<2],B[N<<2];
vector<ll>fz[N<<2],fm[N<<2];
void solve(int k,int l,int r)
{
if(l==r)
{
fz[k].push_back(1);
fm[k].push_back(1);
fm[k].push_back(mod-a[l]);
return;
}
int mid=(l+r)>>1;
solve(lc,l,mid),solve(rc,mid+1,r);
int limit=1;
while(limit<=(r-l+1)) limit<<=1;
for(int i=0,size=fz[lc].size();i<size;i++) fzl[i]=fz[lc][i];
for(int i=0,size=fz[rc].size();i<size;i++) fzr[i]=fz[rc][i];
for(int i=0,size=fm[lc].size();i<size;i++) fml[i]=fm[lc][i];
for(int i=0,size=fm[rc].size();i<size;i++) fmr[i]=fm[rc][i];
NTT(fzl,limit,1),NTT(fzr,limit,1),NTT(fml,limit,1),NTT(fmr,limit,1);
for(int i=0;i<limit;i++)
{
A[i]=(fzl[i]*fmr[i]%mod+fzr[i]*fml[i]%mod)%mod;
B[i]=fml[i]*fmr[i]%mod;
}
NTT(A,limit,-1),NTT(B,limit,-1);
for(int i=0;i<=r-l+1;i++)
fz[k].push_back(A[i]),fm[k].push_back(B[i]);
for(int i=0;i<limit;i++) fzl[i]=fzr[i]=fml[i]=fmr[i]=A[i]=B[i]=0;
fz[lc].clear(),fz[rc].clear(),fm[lc].clear(),fm[rc].clear();
}
void print(ll *a,int l,int r,string s)
{
cout<<s<<"=";
for(int i=l;i<=r;i++)
cout<<a[i]<<" ";
cout<<endl;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%lld",&a[i]);
int limit=1;
while(limit<=(m<<1)) limit<<=1;
init(limit);
fac[0]=1;
for(int i=1;i<=m+1;i++) fac[i]=fac[i-1]*i%mod;
ifac[m+1]=poww(fac[m+1],mod-2);
for(int i=m;i>=0;i--) ifac[i]=ifac[i+1]*(i+1)%mod;
for(int i=0;i<=m;i++) f[i]=ifac[i+1];
getln(f,g,m+1);
solve(1,1,n);
for(int i=0;i<=n;i++) fzl[i]=fz[1][i],fml[i]=fm[1][i];
getinv(fml,A,m+1);
NTT(fzl,limit,1),NTT(A,limit,1);
for(int i=0;i<limit;i++)
A[i]=A[i]*fzl[i]%mod;
NTT(A,limit,-1);
for(int i=0;i<=m;i++)
g[i]=g[i]*A[i]%mod;
for(int i=0;i<limit;i++) A[i]=0;
getexp(g,A,m+1);
for(int i=1;i<=m;i++)
printf("%lld ",fac[i]*A[i]%mod);
return 0;
}