Live2D

Solution -「多校联训」轮回

\(\mathcal{Description}\)

  有 \(n\) 个黑盒,第 \(i\) 个黑盒可以让输入变量以 \(p_i\) 的概率保持不变,以 \(\frac{1-p_i}2\) 的概率加一或减一。称一次从 \(i\) 开始的游戏为:初始变量的值为 \(0\),从 \(i\) 开始,将变量依次输入 \(i,i+1,\dots,n,1,2,\dots,n,\dots\),直到变量的绝对值大于 \(m\)。处理 \(q\) 次操作:

  1. 求此时一次从某个 \(i\) 开始的游戏期望经过多少个黑盒。答案模 \((10^9+7)\)
  2. 修改某个 \(p_i\) 的值。

  \(n,q\le10^5\)\(m\le5\)

\(\mathcal{Solution}\)

  DP 是平凡的:令 \(f_{i,j}\) 表示绝对值为 \(j\) 的变量将要输入第 \(i\) 个黑盒时,期望还能经过多少黑盒。那么

\[f_{i,j}=\begin{cases} 1+p_if_{i+1,j}+(1-p_i)f_{i+1,j+1}&j=0\\ 1+p_if_{i+1,j}+\frac{1-p_i}{2}(f_{i+1,j+1}+f_{i+1,j-1})&j>0 \end{cases}. \]

  紧接着,你看我久违地把 DP 状态写成下标的形式就知道,应用整体思想,令

\[F_i=\begin{bmatrix}f_{0,0}\\ \vdots\\ f_{m,0}\\ 1 \end{bmatrix}, \]

据此不难得出

\[F_i=M_iF_{i+1}, \]

其中 \(M_i\) 是一个仅与 \(p_i\) 有关的转移矩阵。鉴于答案要求 \(f_{i,0}\),我们只需要关心某个特定的 \(F_i\) 的值。从方程角度考虑,发现

\[\begin{aligned} F_i&=M_iF_{i+1}\\ &=M_iM_{i+1}\cdots M_{n}F_1\\ &=M_iM_{i+1}\cdots M_{n}M_1M_2\cdots M_{i-1}F_i \end{aligned}. \]

所以说用线段树维护 \(M\) 的区间积,快速求出 \(F_i=AF_i\) 中的 \(A\),然后 \(\mathcal O(m^3)\) 消元就行。最终最劣复杂度为 \(\mathcal O(qm^3\log n)\)

\(\mathcal{Code}\)

/* Clearink */

#include <cstdio>
#include <cassert>
#include <cstring>
#include <iostream>

#define rep( i, l, r ) for ( int i = l, rep##i = r; i <= rep##i; ++i )
#define per( i, r, l ) for ( int i = r, per##i = l; i >= per##i; --i )

inline int rint() {
	int x = 0, f = 1, s = getchar();
	for ( ; s < '0' || '9' < s; s = getchar() ) f = s == '-' ? -f : f;
	for ( ; '0' <= s && s <= '9'; s = getchar() ) x = x * 10 + ( s ^ '0' );
	return x * f;
}

template<typename Tp>
inline void wint( Tp x ) {
	if ( x < 0 ) putchar( '-' ), x = -x;
	if ( 9 < x ) wint( x / 10 );
	putchar( x % 10 ^ '0' );
}

const int MAXN = 1e5, MAXK = 5, MOD = 1e9 + 7, INV2 = MOD + 1 >> 1;
int n, m /* i.e. k */, q;

inline int mul( const long long a, const int b ) { return a * b % MOD; }
inline int sub( int a, const int b ) { return ( a -= b ) < 0 ? a + MOD : a; }
inline void subeq( int& a, const int b ) { ( a -= b ) < 0 && ( a += MOD ); }
inline int add( int a, const int b ) { return ( a += b ) < MOD ? a : a - MOD; }
inline void addeq( int& a, const int b ) { ( a += b ) >= MOD && ( a -= MOD ); }
inline int mpow( int a, int b ) {
    int ret = 1;
    for ( ; b; a = mul( a, a ), b >>= 1 ) ret = mul( ret, b & 1 ? a : 1 );
    return ret;
}

struct Matrix {
    static const int MAXS = MAXK + 1;
    int mat[MAXS + 1][MAXS + 1];

    Matrix(): mat{} {} // !!!

    inline int* operator [] ( const int k ) { return mat[k]; }

    friend inline Matrix operator * ( const Matrix& u, const Matrix& v ) {
        Matrix ret;
        rep ( i, 0, MAXS ) rep ( k, 0, MAXS ) rep ( j, 0, MAXS ) {
            addeq( ret[i][j], mul( u.mat[i][k], v.mat[k][j] ) );
        }
        return ret;
    }
};

inline void setTMat( Matrix& u, const int p ) {
    u = Matrix();
    u[0][0] = p;
    rep ( i, 0, m + 1 ) u[m + 1][i] = 0, u[i][m + 1] = 1;
    if ( !m ) return ;
    u[0][1] = sub( 1, p );
    
    int cp = mul( INV2, sub( 1, p ) );
    rep ( i, 1, m ) {
        u[i][i] = p, u[i][i - 1] = cp;
        if ( i + 1 <= m ) u[i][i + 1] = cp;
    }
}

struct SegmentTree {
    Matrix sum[MAXN << 2];

    inline void build( const int u, const int l, const int r ) {
        if ( l == r ) {
            int a = rint();
            setTMat( sum[u], mul( a, mpow( rint(), MOD - 2 ) ) );
            return ;
        }
        int mid = l + r >> 1;
        build( u << 1, l, mid ), build( u << 1 | 1, mid + 1, r );
        sum[u] = sum[u << 1] * sum[u << 1 | 1];
    }

    inline void modify( const int u, const int l, const int r,
      const int x, const int p ) {
        if ( l == r ) return setTMat( sum[u], p );
        int mid = l + r >> 1;
        if ( x <= mid ) modify( u << 1, l, mid, x, p );
        else modify( u << 1 | 1, mid + 1, r, x, p );
        sum[u] = sum[u << 1] * sum[u << 1 | 1];
    }

    inline Matrix query( const int u, const int l, const int r,
      const int ql, const int qr ) {
        if ( ql <= l && r <= qr ) return sum[u];
        int mid = l + r >> 1;
        if ( qr <= mid ) return query( u << 1, l, mid, ql, qr );
        if ( mid < ql ) return query( u << 1 | 1, mid + 1, r, ql, qr );
        return query( u << 1, l, mid, ql, qr )
          * query( u << 1 | 1, mid + 1, r, ql, qr );
    }
} sgt;

inline void gauss( const int s, int a[MAXK + 1][MAXK + 2], int* x ) {
    rep ( i, 0, s ) {
        int p;
        rep ( j, i, s ) if ( a[j][i] ) { p = j; break; }
        if ( i != p ) std::swap( a[i], a[p] );
        assert( a[i][i] );

        int iv = mpow( a[i][i], MOD - 2 );
        rep ( j, i + 1, s ) {
            int t = mul( a[j][i], iv );
            rep ( k, i, s + 1 ) subeq( a[j][k], mul( t, a[i][k] ) );
        }
    }

    per ( i, s, 0 ) {
        x[i] = mul( a[i][s + 1], mpow( a[i][i], MOD - 2 ) );
        per ( j, i - 1, 0 ) subeq( a[j][s + 1], mul( x[i], a[j][i] ) );
    }
}

inline int solve( Matrix& u ) {
    int a[MAXK + 1][MAXK + 2], x[MAXK + 1];
    memset( a, 0, sizeof a );

    rep ( i, 0, m ) {
        a[i][i] = MOD - 1;
        rep ( j, 0, m ) addeq( a[i][j], u[i][j] );
        a[i][m + 1] = sub( 0, u[i][m + 1] );
    }

    gauss( m, a, x );
    return x[0];
}

int main() {
    freopen( "samsara.in", "r", stdin );
    freopen( "samsara.out", "w", stdout );

    n = rint(), m = rint(), q = rint();
    sgt.build( 1, 1, n );

    for ( int op, i, a, b; q--; ) {
        op = rint(), i = rint();
        if ( op == 1 ) {
            Matrix tmat;
            if ( i == 1 ) tmat = sgt.query( 1, 1, n, 1, n );
            else {
                tmat = sgt.query( 1, 1, n, i, n )
                  * sgt.query( 1, 1, n, 1, i - 1 );
            }
            wint( solve( tmat ) ), putchar( '\n' );
        } else {
            a = rint(), b = rint();
            sgt.modify( 1, 1, n, i, mul( a, mpow( b, MOD - 2 ) ) );
        }
    }
	return 0;
}

posted @ 2021-06-18 22:24  Rainybunny  阅读(83)  评论(0编辑  收藏  举报