[luogu4705] 玩游戏
题目链接
洛谷:https://www.luogu.org/problemnew/show/P4705
Solution
精神污染....这玩意比数树还难写...就是窝太菜了代码过于冗长然后调试还写了两三K
思路据说比较套路??反正窝是不会
我们可以很容易的把答案写出来:
\[ans_k=\frac{1}{nm}\sum_{i=1}^{n}\sum_{j=1}^m (a_i+b_j)^k
\]
我们忽略掉那个\(nm\)的系数,最后乘回来就好了,然后化简下:
\[\begin{align}
ans_k=&\sum_{i=1}^{n}\sum_{j=1}^{m}\sum_{x=0}^k\binom{k}{x}a_i^xb_i^{k-x}\\
=&\sum_{x=0}^k\binom{k}{x}\sum_{i=1}^{n}a_i^x\sum_{j=1}^{m}b_i^{k-x}
\end{align}
\]
显然这是个卷积形式,然后算法瓶颈就在如何求出\(f(x)=\sum_{i=1}^{n}a_i^x\)。
我们写出这个玩意的生成函数:
\[\begin{align}
F(x)=&\sum_{i=0}^{\infty} f(i)x^i\\
=&\sum_{i=0}^{\infty}\sum_{j=1}^{n}a_j^ix^i\\
=&\sum_{j=1}^{n}\sum_{i=0}^{\infty} a_j^ix^i\\
=&\sum_{j=1}^{n}\frac{1}{1-a_jx}
\end{align}
\]
注意到:
\[(\ln (1-ax))'=\frac{-a}{1-ax}
\]
即:
\[x(\ln(1-ax))'=1-\frac{1}{1-ax}
\]
也就是说:
\[F(x)=\sum_{i=1}^{n}1-x(\ln(1-a_ix))'=n-x\sum_{i=1}^{n}(\ln(1-a_ix))'
\]
注意到导数满足加法律,后面的在化一下就是:
\[\begin{align}
F(x)=&n-x\left(\sum_{i=1}^{n}\ln (1-a_ix)\right)'\\
=&n-x\left(\ln \prod_{i=1}^{n}(1-a_ix)\right)'
\end{align}
\]
注意到里面的连乘形式可以分治\(\rm FFT\)在\(O(n\log ^2 n)\)解决,然后再照着式子算一下就好了,需要写个多项式求\(\ln\),注意要把前面忽略的东西弄回去。
总复杂度\(O(n\log^2 n)\)。
代码大概还能凑合着看吧...
#include<bits/stdc++.h>
using namespace std;
void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}
#define lf double
#define ll long long
#define pii pair<int,int >
#define vec vector<int >
#define pb push_back
#define mp make_pair
#define fr first
#define sc second
#define FOR(i,l,r) for(register int i=l,r_##i=r;i<=r_##i;++i)
const int maxn = 1e6+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 998244353;
int a[maxn],b[maxn],fac[maxn],ifac[maxn],inv[maxn];
int w[maxn],n,m,T,mxn,bit,N,tmp[15][maxn],pos[maxn];
int add(int x,int y) {return x+y>mod?x+y-mod:x+y;}
int del(int x,int y) {return x-y<0?x-y+mod:x-y;}
int mul(int x,int y) {return 1ll*x*y-1ll*x*y/mod*mod;}
int qpow(int aa,int x) {
int res=1;
for(;x;x>>=1,aa=mul(aa,aa)) if(x&1) res=mul(res,aa);
return res;
}
void clear(int *l,int *r) {
if(l>=r) return ;
while(l!=r) *l++=0;*l=0;
}
void ntt_init(int len) {
for(mxn=1;mxn<=len;mxn<<=1);
w[0]=1;w[1]=qpow(3,(mod-1)/mxn);
for(int i=2;i<=mxn;i++) w[i]=mul(w[i-1],w[1]);
inv[0]=inv[1]=fac[0]=ifac[0]=1;
for(int i=2;i<=mxn;i++) inv[i]=mul(mod-mod/i,inv[mod%i]);
for(int i=1;i<=mxn;i++) fac[i]=mul(fac[i-1],i);
for(int i=1;i<=mxn;i++) ifac[i]=mul(ifac[i-1],inv[i]);
}
void get(int len) {for(bit=0,N=1;N<=len;N<<=1,bit++);}
void get_pos() {for(int i=1;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));}
void ntt(int *r,int op) {
for(int i=1;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
for(int i=1,d=mxn>>1;i<N;i<<=1,d>>=1)
for(int j=0;j<N;j+=i<<1)
for(int k=0;k<i;k++) {
int x=r[j+k],y=mul(r[i+j+k],w[k*d]);
r[j+k]=add(x,y),r[i+j+k]=del(x,y);
}
if(op==-1) {
reverse(r+1,r+N);int d=qpow(N,mod-2);
for(int i=0;i<N;i++) r[i]=mul(r[i],d);
}
}
void poly_inv(int *r,int *t,int len) {
if(len==1) return t[0]=qpow(r[0],mod-2),void();
poly_inv(r,t,len>>1);get(len);get_pos();
for(int i=0;i<len;i++) tmp[0][i]=r[i],tmp[1][i]=t[i];
clear(tmp[0]+len,tmp[0]+N);
clear(tmp[1]+len,tmp[1]+N);
ntt(tmp[0],1),ntt(tmp[1],1);
for(int i=0;i<N;i++) t[i]=del(mul(2,tmp[1][i]),mul(mul(tmp[0][i],tmp[1][i]),tmp[1][i]));
ntt(t,-1),clear(t+len,t+N);
}
void poly_der(int *r,int *t,int len) {
get(len);
for(int i=1;i<len;i++) t[i-1]=mul(r[i],i);
clear(t+len-1,t+N);
}
void poly_ln(int *r,int *t,int len) {
poly_inv(r,tmp[3],len);
poly_der(r,tmp[4],len);
get(len),get_pos();
clear(tmp[3]+len,tmp[3]+N),clear(tmp[4]+len,tmp[4]+N);
ntt(tmp[3],1),ntt(tmp[4],1);
for(int i=0;i<N;i++) tmp[5][i]=mul(tmp[3][i],tmp[4][i]);
ntt(tmp[5],-1);
for(int i=0;i<len;i++) t[i+1]=mul(inv[i+1],tmp[5][i]);t[0]=0;
clear(t+len+1,t+N);
}
vector<int > p[maxn];
void dc_fft(int *s,int x,int l,int r) { // divide and conquar of NTT
if(l==r) return p[x].resize(2),p[x][0]=1,p[x][1]=del(0,s[l]),void();
int mid=(l+r)>>1,ls=x<<1,rs=x<<1|1;
dc_fft(s,ls,l,mid),dc_fft(s,rs,mid+1,r);
int r1=mid-l+1,r2=r-mid;
for(int i=0;i<=r1;i++) tmp[6][i]=p[ls][i];
for(int i=0;i<=r2;i++) tmp[7][i]=p[rs][i];
get(r1+r2),get_pos();
clear(tmp[6]+r1+1,tmp[6]+N),clear(tmp[7]+r2+1,tmp[7]+N);
ntt(tmp[6],1),ntt(tmp[7],1);
FOR(i,0,N-1) tmp[8][i]=mul(tmp[6][i],tmp[7][i]);
ntt(tmp[8],-1);
p[x].resize(r1+r2+1);
for(int i=0;i<=r1+r2;i++) p[x][i]=tmp[8][i];
p[ls].clear(),p[rs].clear();
}
int A[maxn],B[maxn];
void solve(int *s,int *t,int len) {
dc_fft(s,1,1,len);get(len);int nn=N;
FOR(i,0,p[1].size()-1) tmp[9][i]=p[1][i];get(T);nn=N;
clear(tmp[9]+p[1].size(),tmp[9]+N);
p[1].clear();
poly_ln(tmp[9],tmp[10],nn);
poly_der(tmp[10],tmp[11],nn);
for(int i=1;i<nn;i++) t[i]=del(0,tmp[11][i-1]);t[0]=len;
for(int i=0;i<nn;i++) t[i]=mul(t[i],ifac[i]);
}
int main() {
read(n),read(m);FOR(i,1,n) read(a[i]);FOR(i,1,m) read(b[i]);read(T);
ntt_init((max(n+m,T))<<1);
solve(a,A,n);
solve(b,B,m);
get(T<<1),get_pos();
clear(A+T+1,A+N),clear(B+T+1,B+N);ntt(A,1),ntt(B,1);
for(int i=0;i<N;i++) A[i]=mul(A[i],B[i]);
ntt(A,-1);
for(int i=1;i<=T;i++) write(mul(mul(fac[i],A[i]),mul(inv[n],inv[m])));
return 0;
}