UOJ#269. 【清华集训2016】如何优雅地求和

原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ269.html

题目传送门 - UOJ269

题意

  有一个多项式函数 $f(x)$,最高次幂为 $x^m$,定义变换 $Q$:

$$Q(f,n,x)=\sum_{k=0}^n f(k)\binom nk x^k(1−x)^{n−k}$$
  现在给定函数 $f$ 和 $n,x$,求 $Q(f,n,x)\mod {\rm 998244353}$。

  $f(x)$ 由 $0$~$m$ 的点值给出。

  $1\leq n\leq 10^9,1\leq m \leq 2\times 10^4, 0\leq a_i,x <998244353$

题解

  cly_none 太强了。

  考虑一个 $m$ 次多项式 $f(x)$ ,必然可以拆成一堆下降幂的和。(忽略系数)其中,最高次项是 $m$ 次项,所以转成下降幂之后,最高次项就是一个 $m$ 阶下降幂。

  对于 $f(x)$ 的某一个下降幂表示,设为 $x^\underline{k}$ ,那么,可以得到:

$$\begin{aligned} & \sum_{i=0}^n i^{\underline k} {n\choose i} x^i (1-x)^{n-i} \\ = & \sum_{i=k}^n i^{\underline k} \frac {n^{\underline k}}{i ^ {\underline k}} {n - k\choose i - k } x^i (1-x)^{n-i} \\ = & n^{\underline k} \sum_{i=k}^n {n - k\choose i - k } x^i (1-x)^{n-i} \\ = & n^{\underline k} x^k \sum_{i=0}^{n-k} {n - k\choose i} x^i (1-x)^{n-k-i} \\ = & n^{\underline k} x^k \end{aligned}$$

  于是这样就证明了题目要求的式子是一个关于 $n$ 的 $m$ 次多项式。

  于是只需要 FFT 一下,求出 $[0,m]$ 之间的整点的点值,然后插值来求答案。由于这些点值十分特殊,所以可以预处理阶乘来 $O(m)$ 求解。

  总的时间复杂度为 $O(m\log m)$ 。

代码

#include <bits/stdc++.h>
using namespace std;
const int N=1<<16,mod=998244353;
int read(){
	int x=0;
	char ch=getchar();
	while (!isdigit(ch))
		ch=getchar();
	while (isdigit(ch))
		x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
	return x;
}
int Pow(int x,int y){
	int ans=1;
	for (;y;y>>=1,x=1LL*x*x%mod)
		if (y&1)
			ans=1LL*ans*x%mod;
	return ans;
}
int n,m,x,a[N],A[N],B[N];
int Fac[N],Inv[N];
int w[N],R[N];
int C(int n,int m){
	if (m>n||m<0)
		return 0;
	return 1LL*Fac[n]*Inv[m]%mod*Inv[n-m]%mod;
}
void FFT(int a[],int n){
	for (int i=0;i<n;i++)
		if (R[i]<i)
			swap(a[R[i]],a[i]);
	for (int t=n>>1,d=1;d<n;d<<=1,t>>=1)
		for (int i=0;i<n;i+=(d<<1))
			for (int j=0;j<d;j++){
				int tmp=1LL*w[t*j]*a[i+j+d]%mod;
				a[i+j+d]=(a[i+j]+mod-tmp)%mod;
				a[i+j]=(a[i+j]+tmp)%mod;
			}
}
void Mul(int a[],int b[],int m){
	int n,d;
	for (n=1,d=0;n<=m*2+2;n<<=1,d++);
	for (int i=0;i<n;i++)
		R[i]=(R[i>>1]>>1)|((i&1)<<(d-1));
	w[0]=1,w[1]=Pow(3,(mod-1)/n);
	for (int i=2;i<n;i++)
		w[i]=1LL*w[i-1]*w[1]%mod;
	FFT(a,n);
	FFT(b,n);
	for (int i=0;i<n;i++)
		a[i]=1LL*a[i]*b[i]%mod;
	w[0]=1,w[1]=Pow(w[1],mod-2);
	for (int i=2;i<n;i++)
		w[i]=1LL*w[i-1]*w[1]%mod;
	FFT(a,n);
	int inv=Pow(n,mod-2);
	for (int i=0;i<n;i++)
		a[i]=1LL*a[i]*inv%mod;
}
int calc(int n){
	int ans=0;
	for (int k=0;k<=n;k++)
		ans=(1LL*a[k]*C(n,k)%mod*Pow(x,k)%mod*Pow(mod+1-x,n-k)%mod+ans)%mod;
	return ans;
}
int main(){
	n=read(),m=read(),x=read();
	for (int i=0;i<=m;i++)
		a[i]=read();
	for (int i=Fac[0]=Inv[0]=1;i<=m;i++){
		Fac[i]=1LL*Fac[i-1]*i%mod;
		Inv[i]=1LL*Inv[i-1]*Pow(i,mod-2)%mod;
	}
	if (n<=m)
		return printf("%d\n",calc(n)),0;
	for (int i=0;i<=m;i++){
		A[i]=1LL*a[i]*Inv[i]%mod*Pow(x,i)%mod;
		B[i]=1LL*Inv[i]*Pow(mod+1-x,i)%mod;
	}
	Mul(A,B,m);
	for (int i=0;i<=m;i++)
		A[i]=1LL*A[i]*Fac[i]%mod;
	int ans=0;
	for (int i=0;i<=m;i++){
		int t=1LL*A[i]*Inv[i]%mod*Inv[m-i]%mod;
		t=1LL*t*Pow(n+mod-i,mod-2)%mod;
		if ((m-i)&1)
			t=(mod-t)%mod;
		ans=(ans+t)%mod;
	}
	for (int i=n;i>=n-m;i--)
		ans=1LL*ans*i%mod;
	printf("%d",ans);
	return 0;
}

  

posted @ 2018-10-04 10:40  zzd233  阅读(653)  评论(0编辑  收藏  举报