玩游戏

又遇到奇怪的错误了,调了好久啊。

一、题目

点此看题

二、解法

首先写出最基础的答案柿子,对于 \(p\in[1,t]\) 答案是这样的:

\[c_p=\sum_{i=1}^n\sum_{j=1}^{m}(a_i+b_j)^p \]

然后考虑二项式展开来化简柿子:

\[c_p=\sum_{i=1}^n\sum_{j=1}^m\sum_{k=0}^{p}a_i^{k}\cdot b_{j}^{p-k}\cdot C(p,k) \]

把组合数拆开然后分配一下各项,可以进一步化简柿子:

\[c_p=p!\sum_{i=0}^n(a_i^k\cdot\frac{1}{k!})\sum_{j=0}^m(b_j^k\cdot\frac{1}{(p-k)!}) \]

上面的柿子似乎可以看成卷积,定义两个生成函数 \(A(x)\)\(B(x)\) 使他们卷积得到答案:

\[A(x)=\sum_{r=0}^\infty x^r\sum_{i=1}^na_i^r,B(x)=\sum_{r=0}^\infty x^r\sum_{i=1}^mb_i^r \]

求出上面的生成函数后直接转 \(\tt EGF\) 即可,不过我们要优化一下生成函数的计算,套路就是先用闭形式化简柿子

\[A(x)=\sum_{i=1}^n\sum_{r=0}^\infty a_i^r\cdot x^r=\sum_{i=1}^n\frac{1}{1-a_ix} \]

但是这个闭形式显然求不了和的,就是因为他是次数为负。可以通过一些奇怪的操作把它转成正次数就好算一些,这道题我们利用到了 \(\ln(x)'=\frac{1}{x}\) 这个性质,先构造出对应的分式以便于化出对数求导形式:

\[A(x)=n-x\sum_{i=1}^n\frac{-a_i}{1-a_ix}=n-x\sum_{i=1}^n\ln(1-a_ix)' \]

\(n\) 次导显然是不行的,我们再略微的变形来平衡下复杂度:

\[A(x)=n-x(\ln\prod_{i=1}^n1-a_ix)' \]

分治乘可以把里面的柿子 \(O(n\log^2 n)\) 可以算,然后 \(O(n\log n)\) 取对数之后 \(O(n)\) 求导,\(B\) 也以同样的方式算出来后 \(O(n\log n)\) 暴力乘就行了。

有一个很奇葩的错误,直接写在本子里了。

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int M = 400005;
const int MOD = 998244353;
#define int long long
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,m,k,t,len,a[M],b[M],c[M],fac[M],inv[M];
int A[M],B[M],F[M],G[M],rev[M];
int qkpow(int a,int b)
{
	int r=1;
	while(b>0)
	{
		if(b&1) r=r*a%MOD;
		a=a*a%MOD;
		b>>=1;
	}
	return r;
}
void NTT(int *a,int len,int tmp)
{
	
	for(int i=0;i<len;i++)
	{
		rev[i]=(rev[i>>1]>>1)|((i&1)*(len/2));
		if(i<rev[i]) swap(a[i],a[rev[i]]);
	}
	for(int s=2;s<=len;s<<=1)
	{
		int t=s/2,w=(tmp==1)?qkpow(3,(MOD-1)/s):qkpow(3,(MOD-1)-(MOD-1)/s);
		for(int i=0;i<len;i+=s)
		{
			for(int j=0,x=1;j<t;j++,x=x*w%MOD)
			{
				int fe=a[i+j],fo=a[i+j+t];
				a[i+j]=(fe+x*fo)%MOD;
				a[i+j+t]=((fe-x*fo)%MOD+MOD)%MOD;
			}
		}
	}
	if(tmp==1) return ;
	int inv=qkpow(len,MOD-2);
	for(int i=0;i<len;i++)
		a[i]=a[i]*inv%MOD; 
}
void init(int n)
{
	fac[0]=inv[0]=inv[1]=1;
	for(int i=2;i<=n;i++) inv[i]=(MOD-MOD/i)*inv[MOD%i]%MOD;
	for(int i=2;i<=n;i++) inv[i]=inv[i-1]*inv[i]%MOD;
	for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%MOD;
}
void solve(int *a,int *b,int x,int l,int r)
{
	int ln=r-l+1; 
	if(l==r)
	{
		a[0]=1;a[1]=MOD-b[l];
		return ;
	}
	int ls=x<<1,rs=x<<1|1,mid=(l+r)>>1;
	int f[3*ln]={},g[3*ln]={};
	solve(f,b,ls,l,mid);
	solve(g,b,rs,mid+1,r);
	len=1;while(len<=ln) len<<=1;
	NTT(f,len,1);NTT(g,len,1);
	for(int i=0;i<len;i++) f[i]=f[i]*g[i]%MOD;
	NTT(f,len,-1);
	for(int i=0;i<=ln;i++) a[i]=f[i];
}
void work(int n,int *a,int *b)
{
	len=1;
	while(len<2*n) len<<=1;
	for(int i=0;i<len;i++) A[i]=B[i]=0;
	for(int i=0;i<n;i++) A[i]=a[i];
	for(int i=0;i<n/2;i++) B[i]=b[i];
	NTT(A,len,1);NTT(B,len,1);
	for(int i=0;i<len;i++)
		A[i]=((2*B[i]-B[i]*B[i]%MOD*A[i])%MOD+MOD)%MOD;
	NTT(A,len,-1);
	for(int i=0;i<n;i++) b[i]=A[i];
}
void get(int n,int *a,int *b)
{
	//memset(b,0,sizeof b);
	//直接上面那样写是错的!!! 
	b[0]=qkpow(a[0],MOD-2);int cur=1;
	while(cur<n)
	{
		cur<<=1;
		work(cur,a,b);
	}
}
void ln(int n,int *a,int *b)
{
	memset(c,0,sizeof c);
	get(n,a,c);
	for(int i=0;i<n;i++)//求导
		a[i]=(i+1)*a[i+1]%MOD;
	len=1;while(len<2*n) len<<=1;
	NTT(a,len,1);NTT(c,len,1);
	for(int i=0;i<len;i++) a[i]=a[i]*c[i]%MOD;
	NTT(a,len,-1);
	for(int i=0;i<n;i++)//求积分 
		b[i+1]=a[i]*qkpow(i+1,MOD-2)%MOD;
}
signed main()
{
	init(200000);
	n=read();m=read();
	for(int i=1;i<=n;i++)
		a[i]=read();
	for(int i=1;i<=m;i++)
		b[i]=read();
	t=read();
	k=max(n,max(m,t));
	solve(F,a,1,1,n);
	solve(G,b,1,1,m);
	//多项式取对 
	ln(k+1,F,a);
	ln(k+1,G,b);
	//求导 
	for(int i=0;i<=k;i++)
	{
		a[i]=(i+1)*a[i+1]%MOD;
		b[i]=(i+1)*b[i+1]%MOD;
	}
	a[k+1]=b[k+1]=0;
	//求-x 
	for(int i=k;i>0;i--)
	{
		a[i]=MOD-a[i-1];
		b[i]=MOD-b[i-1];
	}
	a[0]=n;b[0]=m;
	//转成EGF 
	for(int i=2;i<=k;i++)
	{
		a[i]=a[i]*inv[i]%MOD;
		b[i]=b[i]*inv[i]%MOD;
	}
	len=1;while(len<2*k+2) len<<=1;
	NTT(a,len,1);NTT(b,len,1);
	for(int i=0;i<len;i++) a[i]=a[i]*b[i]%MOD;
	NTT(a,len,-1);
	int tmp=qkpow(n*m%MOD,MOD-2);
	for(int i=1;i<=t;i++)
		printf("%lld\n",a[i]*tmp%MOD*fac[i]%MOD);
}
posted @ 2021-03-10 21:40  C202044zxy  阅读(144)  评论(0编辑  收藏  举报