「ARC154F」Dice Game

题目

点这里看题目。


你有一个 \(n\) 个面的骰子。每次抛骰子的时候,每个面出现的概率相等。

现在开始抛骰子。设 \(X\) 为每个面都被扔出至少一次时的抛骰子次数,你需要对于 \(1\le k\le m\),求出 \(E[X^k]\)。答案均对 \(998244353\) 取模。

所有数据满足 \(1\le n,m\le 2\times 10^5\)

分析

我们观察答案的形式:内层是普通幂,外层是期望。

通常而言,处理普通幂或是变成下降幂,或是变成复合 \(e^x\)​​。因为外层还有一个期望待解决,我们先试试写出概率生成函数,并复合 \(e^x\)​​。如果概率生成函数是 \(F(x)\)​​,则只需要计算为 \(F(e^x)\bmod x^{m+1}\)​​。

只需要普通幂,我们直接写 OGF:

\[\begin{aligned} F(x) &=n\sum_{k\ge 0}(n-1)!\left\{\begin{matrix}k\\n-1\end{matrix}\right\}\frac{x^{k+1}}{n^{k+1}}\\ &=(n-1)!x\cdot \sum_{k\ge 0}\left\{\begin{matrix}k\\n-1\end{matrix}\right\} \frac{x^k}{n^k}\\ &=(n-1)!x\cdot \prod_{k=1}^{n-1}\frac{\frac{x}{n}}{1-\frac{kx}{n}}\\ &=(n-1)x^n\cdot\prod_{k=1}^{n-1}\frac{1}{n-kx} \end{aligned} \]

Note.

此处用了一个不那么众所周知的东西——一列的第二类斯特林数斯特林数的 OGF 为:

\[\sum_{n\ge 0}\left\{\begin{matrix}n\\m\end{matrix}\right\}x^n =\prod_{k=1}^{m}\frac{x}{1-kx} \]

不过,这个可以直接由第二类斯特林数的递推式得到。

好消息,求第二类斯特林数不用写快速幂辣!

现在再来计算复合。不妨先考察有限多项式的情况:假如有一个 \(n\)​​​ 次多项式 \(F(x)=\sum_{k=0}^nf_kx^k\)​​​,​我们想要求出 \(F(e^x)\bmod x^{m+1}\)​。我们注意到,这个时候 EGF 和 OGF 之间没有本质区别,所以我们求一个 \(\sum_{j=0}^mx^j\sum_{k=0}^nf_kk^j\)​ 也行。交换和式,进行收拢,最终我们可以得到结果就是 \(\sum_{k=0}^n\frac{f_k}{1-kx}\bmod x^{m+1}\),可以 \(O(n\log^2 n+m\log m)\) 计算。

Remark.

现在我连交换和式都不会啦!

这道题里面多次出现了 EGF 和 OGF 的互相转换,在原题解里就出现了两次(计算概率生成函数也是靠转换得到的)。其核心就是 \(e^{nx}\leftrightarrow \frac{1}{1-nx}\)​​,两者在系数上仅仅差一个阶乘。而 OGF 在运算上(指多项式工业操作)会更简单,所以这样的转化有简化运算的功能


去网上查了一下,发现有一个叫做形式 Laplace-Borel 变换的东西可以完成一般的 EGF 和 OGF 的转换:

\(F(z)=\sum_{n\ge 0}f_nz^n,\hat F(z)=\sum_{n\ge 0}f_n\frac{z^n}{n!}\)​,则:

\[\begin{aligned} F(z)&=\int_0^\infty \hat F(tz)e^{-t}\,\mathrm dt\\ \hat F(z)&=\frac{1}{2\pi}\int_{-\pi}^\pi F(ze^{i\theta })e^{e^{i\theta}}\,\mathrm d\theta \end{aligned} \]

我的评价是:不懂,希望我永远不会用到。

最后需要注意的问题是:逻辑上来说,我们应该先对分母复合,对于 \(x^{m+1}\)​​​ 截断后再求逆。如果先求出 \(\bmod x^{m+1}\)​​​​ 意义下的逆元再复合,则有可能出问题,因为分母的逆不是有限多项式,直接复合 \(e^x\)​​ 也不是良定义的。

反正,最终可以在 \(O(n\log^2n+m\log m)\) 的时间内解决!

代码

#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>

#define rep( i, a, b ) for( int i = (a) ; i <= (b) ; i ++ )
#define per( i, a, b ) for( int i = (a) ; i >= (b) ; i -- )

const int mod = 998244353;
const int MAXN = ( 1 << 19 ) + 5;

template<typename _T>
inline void Read( _T &x ) {
	x = 0; char s = getchar(); bool f = false;
	while( ! ( '0' <= s && s <= '9' ) ) { f = s == '-', s = getchar(); }
	while( '0' <= s && s <= '9' ) { x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar(); }
	if( f ) x = -x;
}

template<typename _T>
inline void Write( _T x ) {
	if( x < 0 ) putchar( '-' ), x = -x;
	if( 9 < x ) Write( x / 10 );
	putchar( x % 10 + '0' );
}

typedef std :: vector<int> Poly;
typedef std :: pair<Poly, Poly> Frac;

Frac uni[MAXN];
Poly ply[MAXN];

int fac[MAXN], ifac[MAXN];

int N, M;

inline int Qkpow( int, int );
inline int Inv( const int &a ) { return Qkpow( a, mod - 2 ); }
inline int Mul( int x, const int &v ) { return 1ll * x * v % mod; }
inline int Sub( int x, const int &v ) { return ( x -= v ) < 0 ? x + mod : x; }
inline int Add( int x, const int &v ) { return ( x += v ) >= mod ? x - mod : x; }

inline int& MulEq( int &x, const int &v ) { return x = 1ll * x * v % mod; }
inline int& SubEq( int &x, const int &v ) { return ( x -= v ) < 0 ? ( x += mod ) : x; }
inline int& AddEq( int &x, const int &v ) { return ( x += v ) >= mod ? ( x -= mod ) : x; }

inline int Qkpow( int base, int indx ) {
	int ret = 1;
	while( indx ) {
		if( indx & 1 ) MulEq( ret, base );
		MulEq( base, base ), indx >>= 1;
	}
	return ret;
}

namespace Basics {
	const int L = 19, g = 3, phi = mod - 1;

	int w[MAXN];

	inline void NTTInit( const int &n = 1 << L ) {
		w[0] = 1, w[1] = Qkpow( g, phi >> L );
		rep( i, 2, n - 1 ) w[i] = Mul( w[i - 1], w[1] );
	}

	inline void DIF( int *coe, const int &n ) {
		int *wp, p, e, o;
		for( int s = n >> 1 ; s ; s >>= 1 )
			for( int i = 0 ; i < n ; i += s << 1 ) {
				p = ( 1 << L ) / ( s << 1 ), wp = w;
				for( int j = 0 ; j < s ; j ++, wp += p ) {
					e = coe[i + j], o = coe[i + j + s];
					coe[i + j] = Add( e, o );
					coe[i + j + s] = Mul( *wp, Sub( e, o ) );
				}
			}
	}

	inline void DIT( int *coe, const int &n ) {
		int *wp, p, k;
		for( int s = 1 ; s < n ; s <<= 1 )
			for( int i = 0 ; i < n ; i += s << 1 ) {
				p = ( 1 << L ) / ( s << 1 ), wp = w;
				for( int j = 0 ; j < s ; j ++, wp += p )
					k = Mul( *wp, coe[i + j + s] ),
					coe[i + j + s] = Sub( coe[i + j], k ),
					coe[i + j] = Add( coe[i + j], k );
			}
		std :: reverse( coe + 1, coe + n );
		int inv = Inv( n ); rep( i, 0, n - 1 ) MulEq( coe[i], inv );
	}
}

namespace PolyInv {
	int P[MAXN], U[MAXN], Q[MAXN], V[MAXN];

	void Newton( const int &n ) {
		if( n == 1 ) {
			U[0] = Inv( P[0] );
			return ;
		}
		int h = ( n + 1 ) >> 1, L; Newton( h );
		for( L = 1 ; L <= n + h - 2 ; L <<= 1 );
		rep( i, 0, L - 1 ) Q[i] = V[i] = 0;
		rep( i, 0, n - 1 ) Q[i] = P[i];
		rep( i, 0, h - 1 ) V[i] = U[i];
		Basics :: DIF( Q, L );
		Basics :: DIF( V, L );
		rep( i, 0, L - 1 ) V[i] = Mul( V[i], Sub( 2, Mul( Q[i], V[i] ) ) );
		Basics :: DIT( V, L );
		rep( i, h, n - 1 ) U[i] = V[i];
	}

	inline void PolyInv( int *ret, const int *A, const int &n ) {
		rep( i, 0, n - 1 ) P[i] = A[i];
		Newton( n );
		rep( i, 0, n - 1 ) ret[i] = U[i];
	}
}

namespace PolyOperation {
	int P[MAXN] = {}, Q[MAXN] = {};

	inline Poly operator + ( const Poly &a, const Poly &b ) {
		int n = a.size(), m = b.size();
		Poly ret( std :: max( n, m ), 0 );
		for( int i = 0 ; i < n || i < m ; i ++ ) {
			if( i < n ) AddEq( ret[i], a[i] );
			if( i < m ) AddEq( ret[i], b[i] );
		}
		return ret;
	}
	
	inline Poly operator * ( const Poly &a, const Poly &b ) {
		int n = a.size(), m = b.size(), L;
		for( L = 1 ; L <= n + m - 2 ; L <<= 1 );
		rep( i, 0, L - 1 ) P[i] = Q[i] = 0;
		rep( i, 0, n - 1 ) P[i] = a[i];
		rep( i, 0, m - 1 ) Q[i] = b[i];
		Basics :: DIF( P, L );
		Basics :: DIF( Q, L );
		rep( i, 0, L - 1 ) MulEq( P[i], Q[i] );
		Basics :: DIT( P, L );
		return std :: vector<int> ( P, P + n + m - 1 );
	}
	
	inline Poly Inv( const Poly &a ) {
		int n = a.size();
		rep( i, 0, n - 1 ) P[i] = a[i];
		PolyInv :: PolyInv( P, P, n );
		return Poly( P, P + n );
	}
}	

inline Frac operator + ( const Frac &a, const Frac &b ) {
	using namespace PolyOperation;

	return ( Frac ) { a.first * b.second + a.second * b.first, a.second * b.second };
}

Poly Multiply( const int &l, const int &r ) {
	using namespace PolyOperation;
	
	if( l > r ) return { 1 };
	if( l == r ) return ply[l];
	int mid = ( l + r ) >> 1;
	return Multiply( l, mid ) * Multiply( mid + 1, r );
}

Frac Plus( const int &l, const int &r ) {
	using namespace PolyOperation;
	
	if( l > r ) return { Poly(), ( Poly ) { 1 } };
	if( l == r ) return uni[l];
	int mid = ( l + r ) >> 1;
	return Plus( l, mid ) + Plus( mid + 1, r );
}

inline void Init( const int &n ) {
	Basics :: NTTInit();
	fac[0] = 1; rep( i, 1, n ) fac[i] = Mul( fac[i - 1], i );
	ifac[n] = Inv( fac[n] ); per( i, n - 1, 0 ) ifac[i] = Mul( ifac[i + 1], i + 1 );
}

int main() {
	using namespace PolyOperation;
	
	Read( N ), Read( M ), Init( std :: max( N, M ) );
	Poly E( M + 1, 0 ); E[0] = 1;
	rep( i, 1, M ) E[i] = Mul( E[i - 1], N );
	rep( i, 0, M ) MulEq( E[i], Mul( ifac[i], fac[N - 1] ) );

	rep( i, 1, N - 1 ) ply[i] = { N, mod - i };
	Poly fm( Multiply( 1, N - 1 ) );
	rep( i, 0, N - 1 ) uni[i] = { ( Poly ) { fm[i] }, ( Poly ) { 1, ( mod - i ) % mod } };
	Frac fr( Plus( 0, N - 1 ) );
	fr.first.resize( M + 1 ), fr.second.resize( M + 1 );
	fm = fr.first * Inv( fr.second ), fm.resize( M + 1 );
	rep( i, 0, M ) MulEq( fm[i], ifac[i] );
	fm = E * Inv( fm );
	rep( i, 1, M ) Write( Mul( fm[i], fac[i] ) ), putchar( '\n' );
	return 0;
}
posted @ 2023-01-29 19:30  crashed  阅读(115)  评论(0编辑  收藏  举报