【LGP4705】玩游戏
显然这个题的期望就是逗你玩的,我们算出来总贡献除以\(nm\)就好了
设\(ans_t=\sum_{i=1}^n\sum_{j=1}^n(a_i+b_j)^t\)
二项式定理展开一下
\[ans_t=t!\sum_{i=0}^t\frac{\sum_{j=1}^na_j^i}{i!}\frac{\sum_{j=1}^mb_j^{t-i}}{(t-i)!}
\]
我们构造两个多项式\(A,B\)
\[A(x)=\sum_{i=1}^na_i^x,B(x)=\sum_{i=1}^mb_i^x
\]
显然这两个多项式一卷就是答案了
现在的问题就是求\(A\)和\(B\)了
考虑一下生成函数
显然\(A\)的生成函数就是每一个\(a_i\)的生成函数的和
对于一个\(a_i\)其生成函数显然是\(\frac{1}{1-a_ix}\)
于是
\[A=\sum_{i=1}^n\frac{1}{1-a_ix}
\]
暴力加显然是不行的,我们考虑到这些个多项式尽管每一个都是分式但是次数都是\(1\),于是我们可以分治做这个加法,就是合并左右两边的时候先通分再做加法,复杂度是\(O(nlog^2n)\)的
代码
#include<vector>
#include<cstdio>
#include<cstring>
#define re register
#define pb push_back
#define max(a,b) ((a)>(b)?(a):(b))
const int maxn=262144+5;
const int mod=998244353;
const int G[2]={3,(mod+1)/3};
int a[maxn],b[maxn],c[maxn],d[maxn],g[maxn],h[maxn],H[maxn],K[maxn],C[maxn];
int n,len,rev[maxn],m,T,fac[maxn],__[2][100],ifac[maxn],inv[maxn],A[2][maxn>>1];
std::vector<int> q[2][maxn],p[2][maxn];
inline int read() {
char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
inline int ksm(int a,int b) {
int S=1;for(;b;b>>=1,a=1ll*a*a%mod) if(b&1) S=1ll*S*a%mod;return S;
}
inline void NTT(int *f,int o) {
for(re int i=0;i<len;i++) if(i<rev[i]) std::swap(f[i],f[rev[i]]);
for(re int og1,p=0,i=2,ln=1;i<=len;i<<=1,ln<<=1,++p) {
if(!__[o][p]) og1=__[o][p]=ksm(G[o],(mod-1)/i);else og1=__[o][p];
for(re int t,og=1,l=0;l<len;l+=i,og=1)
for(re int x=l;x<l+ln;++x) {
t=1ll*og*f[x+ln]%mod;og=1ll*og*og1%mod;
f[x+ln]=(f[x]-t+mod)%mod,f[x]=(f[x]+t)%mod;
}
}
if(!o) return;
int Inv=ksm(len,mod-2);
for(re int i=0;i<len;i++) f[i]=1ll*f[i]*Inv%mod;
}
inline void Inv(int n,int *A,int *B) {
if(n==1) {B[0]=ksm(A[0],mod-2);return;}
Inv((n+1)>>1,A,B);
len=1;while(len<n+n-1) len<<=1;
for(re int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|((i&1)?len>>1:0);
for(re int i=0;i<n;i++) C[i]=A[i];
for(re int i=n;i<len;i++) C[i]=0;NTT(C,0),NTT(B,0);
for(re int i=0;i<len;i++) B[i]=(2ll*B[i]-1ll*C[i]*B[i]%mod*B[i]%mod+mod)%mod;
NTT(B,1);for(re int i=n;i<len;i++) B[i]=0;
}
void cdq(int l,int r,int o,int t) {
if(l==r) {q[o][t].pb(1);q[o][t].pb(mod-A[o][l]);p[o][t].pb(1);return;}
int mid=l+r>>1;cdq(l,mid,o,t<<1),cdq(mid+1,r,o,t<<1|1);
len=1;while(len<=r-l+1) len<<=1;
for(re int i=0;i<len;i++) g[i]=h[i]=a[i]=b[i]=c[i]=d[i]=0;
for(re int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|((i&1)?len>>1:0);
for(re int i=0;i<q[o][t<<1].size();i++) a[i]=q[o][t<<1][i];
for(re int i=0;i<p[o][t<<1].size();i++) c[i]=p[o][t<<1][i];
for(re int i=0;i<q[o][t<<1|1].size();i++) b[i]=q[o][t<<1|1][i];
for(re int i=0;i<p[o][t<<1|1].size();i++) d[i]=p[o][t<<1|1][i];
NTT(a,0),NTT(b,0),NTT(c,0),NTT(d,0);
for(re int i=0;i<len;i++) g[i]=1ll*a[i]*b[i]%mod;
for(re int i=0;i<len;i++) h[i]=(1ll*a[i]*d[i]%mod+1ll*b[i]*c[i]%mod)%mod;
NTT(g,1),NTT(h,1);
for(re int i=0;i<r-l+1;i++) p[o][t].pb(h[i]);
for(re int i=0;i<=r-l+1;i++) q[o][t].pb(g[i]);
}
int main() {
n=read(),m=read();inv[1]=1;ifac[0]=1;fac[0]=1;
for(re int i=1;i<=n;i++) A[0][i]=read();
for(re int i=1;i<=m;i++) A[1][i]=read();
cdq(1,n,0,1),cdq(1,m,1,1);T=read();
for(re int i=2;i<=T;i++) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
for(re int i=1;i<=T;i++) fac[i]=1ll*fac[i-1]*i%mod,ifac[i]=1ll*ifac[i-1]*inv[i]%mod;
memset(a,0,sizeof(a)),memset(b,0,sizeof(b));
memset(c,0,sizeof(c)),memset(d,0,sizeof(d));
for(re int i=0;i<p[0][1].size();i++) a[i]=p[0][1][i];
for(re int i=0;i<p[1][1].size();i++) b[i]=p[1][1][i];
for(re int i=0;i<q[0][1].size();i++) c[i]=q[0][1][i];
for(re int i=0;i<q[1][1].size();i++) d[i]=q[1][1][i];
Inv(T+1,c,H),Inv(T+1,d,K);
int U=max(n,m);U=max(U,T);
len=1;while(len<=T+U) len<<=1;
for(re int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|((i&1)?len>>1:0);
NTT(a,0),NTT(b,0),NTT(H,0),NTT(K,0);
for(re int i=0;i<len;i++) a[i]=1ll*a[i]*H[i]%mod;
for(re int i=0;i<len;i++) b[i]=1ll*b[i]*K[i]%mod;
NTT(a,1),NTT(b,1);
for(re int i=T+1;i<len;i++) a[i]=b[i]=0;
for(re int i=0;i<=T;i++) a[i]=1ll*a[i]*ifac[i]%mod;
for(re int i=0;i<=T;i++) b[i]=1ll*b[i]*ifac[i]%mod;
NTT(a,0),NTT(b,0);
for(re int i=0;i<len;i++) a[i]=1ll*a[i]*b[i]%mod;
NTT(a,1);
int Inv=ksm(1ll*n*m%mod,mod-2);
for(re int i=1;i<=T;i++)
printf("%d\n",1ll*Inv*a[i]%mod*fac[i]%mod);
return 0;
}
还有压行真是好看