「ARC154F」Dice Game
题目
点这里看题目。
你有一个 \(n\) 个面的骰子。每次抛骰子的时候,每个面出现的概率相等。
现在开始抛骰子。设 \(X\) 为每个面都被扔出至少一次时的抛骰子次数,你需要对于 \(1\le k\le m\),求出 \(E[X^k]\)。答案均对 \(998244353\) 取模。
所有数据满足 \(1\le n,m\le 2\times 10^5\)。
分析
我们观察答案的形式:内层是普通幂,外层是期望。
通常而言,处理普通幂或是变成下降幂,或是变成复合 \(e^x\)。因为外层还有一个期望待解决,我们先试试写出概率生成函数,并复合 \(e^x\)。如果概率生成函数是 \(F(x)\),则只需要计算为 \(F(e^x)\bmod x^{m+1}\)。
只需要普通幂,我们直接写 OGF:
Note.
此处用了一个不那么众所周知的东西——一列的第二类斯特林数斯特林数的 OGF 为:
\[\sum_{n\ge 0}\left\{\begin{matrix}n\\m\end{matrix}\right\}x^n =\prod_{k=1}^{m}\frac{x}{1-kx} \]不过,这个可以直接由第二类斯特林数的递推式得到。
好消息,求第二类斯特林数不用写快速幂辣!
现在再来计算复合。不妨先考察有限多项式的情况:假如有一个 \(n\) 次多项式 \(F(x)=\sum_{k=0}^nf_kx^k\),我们想要求出 \(F(e^x)\bmod x^{m+1}\)。我们注意到,这个时候 EGF 和 OGF 之间没有本质区别,所以我们求一个 \(\sum_{j=0}^mx^j\sum_{k=0}^nf_kk^j\) 也行。交换和式,进行收拢,最终我们可以得到结果就是 \(\sum_{k=0}^n\frac{f_k}{1-kx}\bmod x^{m+1}\),可以 \(O(n\log^2 n+m\log m)\) 计算。
Remark.
现在我连交换和式都不会啦!这道题里面多次出现了 EGF 和 OGF 的互相转换,在原题解里就出现了两次(计算概率生成函数也是靠转换得到的)。其核心就是 \(e^{nx}\leftrightarrow \frac{1}{1-nx}\),两者在系数上仅仅差一个阶乘。而 OGF 在运算上(指多项式工业操作)会更简单,所以这样的转化有简化运算的功能。
去网上查了一下,发现有一个叫做形式 Laplace-Borel 变换的东西可以完成一般的 EGF 和 OGF 的转换:
设 \(F(z)=\sum_{n\ge 0}f_nz^n,\hat F(z)=\sum_{n\ge 0}f_n\frac{z^n}{n!}\),则:
\[\begin{aligned} F(z)&=\int_0^\infty \hat F(tz)e^{-t}\,\mathrm dt\\ \hat F(z)&=\frac{1}{2\pi}\int_{-\pi}^\pi F(ze^{i\theta })e^{e^{i\theta}}\,\mathrm d\theta \end{aligned} \]我的评价是:不懂,希望我永远不会用到。
最后需要注意的问题是:逻辑上来说,我们应该先对分母复合,对于 \(x^{m+1}\) 截断后再求逆。如果先求出 \(\bmod x^{m+1}\) 意义下的逆元再复合,则有可能出问题,因为分母的逆不是有限多项式,直接复合 \(e^x\) 也不是良定义的。
反正,最终可以在 \(O(n\log^2n+m\log m)\) 的时间内解决!
代码
#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>
#define rep( i, a, b ) for( int i = (a) ; i <= (b) ; i ++ )
#define per( i, a, b ) for( int i = (a) ; i >= (b) ; i -- )
const int mod = 998244353;
const int MAXN = ( 1 << 19 ) + 5;
template<typename _T>
inline void Read( _T &x ) {
x = 0; char s = getchar(); bool f = false;
while( ! ( '0' <= s && s <= '9' ) ) { f = s == '-', s = getchar(); }
while( '0' <= s && s <= '9' ) { x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar(); }
if( f ) x = -x;
}
template<typename _T>
inline void Write( _T x ) {
if( x < 0 ) putchar( '-' ), x = -x;
if( 9 < x ) Write( x / 10 );
putchar( x % 10 + '0' );
}
typedef std :: vector<int> Poly;
typedef std :: pair<Poly, Poly> Frac;
Frac uni[MAXN];
Poly ply[MAXN];
int fac[MAXN], ifac[MAXN];
int N, M;
inline int Qkpow( int, int );
inline int Inv( const int &a ) { return Qkpow( a, mod - 2 ); }
inline int Mul( int x, const int &v ) { return 1ll * x * v % mod; }
inline int Sub( int x, const int &v ) { return ( x -= v ) < 0 ? x + mod : x; }
inline int Add( int x, const int &v ) { return ( x += v ) >= mod ? x - mod : x; }
inline int& MulEq( int &x, const int &v ) { return x = 1ll * x * v % mod; }
inline int& SubEq( int &x, const int &v ) { return ( x -= v ) < 0 ? ( x += mod ) : x; }
inline int& AddEq( int &x, const int &v ) { return ( x += v ) >= mod ? ( x -= mod ) : x; }
inline int Qkpow( int base, int indx ) {
int ret = 1;
while( indx ) {
if( indx & 1 ) MulEq( ret, base );
MulEq( base, base ), indx >>= 1;
}
return ret;
}
namespace Basics {
const int L = 19, g = 3, phi = mod - 1;
int w[MAXN];
inline void NTTInit( const int &n = 1 << L ) {
w[0] = 1, w[1] = Qkpow( g, phi >> L );
rep( i, 2, n - 1 ) w[i] = Mul( w[i - 1], w[1] );
}
inline void DIF( int *coe, const int &n ) {
int *wp, p, e, o;
for( int s = n >> 1 ; s ; s >>= 1 )
for( int i = 0 ; i < n ; i += s << 1 ) {
p = ( 1 << L ) / ( s << 1 ), wp = w;
for( int j = 0 ; j < s ; j ++, wp += p ) {
e = coe[i + j], o = coe[i + j + s];
coe[i + j] = Add( e, o );
coe[i + j + s] = Mul( *wp, Sub( e, o ) );
}
}
}
inline void DIT( int *coe, const int &n ) {
int *wp, p, k;
for( int s = 1 ; s < n ; s <<= 1 )
for( int i = 0 ; i < n ; i += s << 1 ) {
p = ( 1 << L ) / ( s << 1 ), wp = w;
for( int j = 0 ; j < s ; j ++, wp += p )
k = Mul( *wp, coe[i + j + s] ),
coe[i + j + s] = Sub( coe[i + j], k ),
coe[i + j] = Add( coe[i + j], k );
}
std :: reverse( coe + 1, coe + n );
int inv = Inv( n ); rep( i, 0, n - 1 ) MulEq( coe[i], inv );
}
}
namespace PolyInv {
int P[MAXN], U[MAXN], Q[MAXN], V[MAXN];
void Newton( const int &n ) {
if( n == 1 ) {
U[0] = Inv( P[0] );
return ;
}
int h = ( n + 1 ) >> 1, L; Newton( h );
for( L = 1 ; L <= n + h - 2 ; L <<= 1 );
rep( i, 0, L - 1 ) Q[i] = V[i] = 0;
rep( i, 0, n - 1 ) Q[i] = P[i];
rep( i, 0, h - 1 ) V[i] = U[i];
Basics :: DIF( Q, L );
Basics :: DIF( V, L );
rep( i, 0, L - 1 ) V[i] = Mul( V[i], Sub( 2, Mul( Q[i], V[i] ) ) );
Basics :: DIT( V, L );
rep( i, h, n - 1 ) U[i] = V[i];
}
inline void PolyInv( int *ret, const int *A, const int &n ) {
rep( i, 0, n - 1 ) P[i] = A[i];
Newton( n );
rep( i, 0, n - 1 ) ret[i] = U[i];
}
}
namespace PolyOperation {
int P[MAXN] = {}, Q[MAXN] = {};
inline Poly operator + ( const Poly &a, const Poly &b ) {
int n = a.size(), m = b.size();
Poly ret( std :: max( n, m ), 0 );
for( int i = 0 ; i < n || i < m ; i ++ ) {
if( i < n ) AddEq( ret[i], a[i] );
if( i < m ) AddEq( ret[i], b[i] );
}
return ret;
}
inline Poly operator * ( const Poly &a, const Poly &b ) {
int n = a.size(), m = b.size(), L;
for( L = 1 ; L <= n + m - 2 ; L <<= 1 );
rep( i, 0, L - 1 ) P[i] = Q[i] = 0;
rep( i, 0, n - 1 ) P[i] = a[i];
rep( i, 0, m - 1 ) Q[i] = b[i];
Basics :: DIF( P, L );
Basics :: DIF( Q, L );
rep( i, 0, L - 1 ) MulEq( P[i], Q[i] );
Basics :: DIT( P, L );
return std :: vector<int> ( P, P + n + m - 1 );
}
inline Poly Inv( const Poly &a ) {
int n = a.size();
rep( i, 0, n - 1 ) P[i] = a[i];
PolyInv :: PolyInv( P, P, n );
return Poly( P, P + n );
}
}
inline Frac operator + ( const Frac &a, const Frac &b ) {
using namespace PolyOperation;
return ( Frac ) { a.first * b.second + a.second * b.first, a.second * b.second };
}
Poly Multiply( const int &l, const int &r ) {
using namespace PolyOperation;
if( l > r ) return { 1 };
if( l == r ) return ply[l];
int mid = ( l + r ) >> 1;
return Multiply( l, mid ) * Multiply( mid + 1, r );
}
Frac Plus( const int &l, const int &r ) {
using namespace PolyOperation;
if( l > r ) return { Poly(), ( Poly ) { 1 } };
if( l == r ) return uni[l];
int mid = ( l + r ) >> 1;
return Plus( l, mid ) + Plus( mid + 1, r );
}
inline void Init( const int &n ) {
Basics :: NTTInit();
fac[0] = 1; rep( i, 1, n ) fac[i] = Mul( fac[i - 1], i );
ifac[n] = Inv( fac[n] ); per( i, n - 1, 0 ) ifac[i] = Mul( ifac[i + 1], i + 1 );
}
int main() {
using namespace PolyOperation;
Read( N ), Read( M ), Init( std :: max( N, M ) );
Poly E( M + 1, 0 ); E[0] = 1;
rep( i, 1, M ) E[i] = Mul( E[i - 1], N );
rep( i, 0, M ) MulEq( E[i], Mul( ifac[i], fac[N - 1] ) );
rep( i, 1, N - 1 ) ply[i] = { N, mod - i };
Poly fm( Multiply( 1, N - 1 ) );
rep( i, 0, N - 1 ) uni[i] = { ( Poly ) { fm[i] }, ( Poly ) { 1, ( mod - i ) % mod } };
Frac fr( Plus( 0, N - 1 ) );
fr.first.resize( M + 1 ), fr.second.resize( M + 1 );
fm = fr.first * Inv( fr.second ), fm.resize( M + 1 );
rep( i, 0, M ) MulEq( fm[i], ifac[i] );
fm = E * Inv( fm );
rep( i, 1, M ) Write( Mul( fm[i], fac[i] ) ), putchar( '\n' );
return 0;
}