抽牌游戏

题意简述

一副 \(n+m\) 张牌的扑克牌,其中 \(m\) 张是 joker,其余 \(n\) 张是互不相同的普通牌。初始时牌堆里有这样一整副牌。每次从牌堆中等概率随机抽一张牌拿走;如果抽到的是 joker,就将所有牌放回牌堆并打乱。问抽到过所有 \(n\) 张普通牌时的期望抽牌次数是多少?答案对 \(M = 19260817\) 取模。

\(n \leq 10^8\),\(m \leq 10^{18}\)。

题目分析

概率期望类题目,考虑 DP,并且期望 DP 套路是从后往前递推。

显然应该可以状压 DP,但是其非常不利于后续优化。所以尝试使用线性 DP。

DP 记录值显然是当前状态到终态需要的期望抽牌次数。状态有哪些呢?牌堆、抽到过哪些牌了。如果朴素记录就是状压了,但是我们发现,\(n\) 张普通牌和 \(m\) 张 joker 并没有本质差别,是等价的,无非前者需要区分有没有被抽到过。

不妨使用 \(f_{i,j}\) 表示当前已经抽到过 \(i\) 张普通牌、牌堆里还有 \(j\) 张牌时,到达终态的期望抽牌次数。我们需要明确哪些状态是合法的:牌堆中必然包含全部 \(m\) 张 joker 以及 \(n-i\) 张尚未抽到过的普通牌,即 \(j \geq n+m-i\)。当 \(i+j > n+m\) 时,表示有 \(i+j-n-m\) 张牌已经抽到过,但后来又被放回了牌堆中。

明确好状态,就可以转移了。我们有 \(\frac{n-i}{j}\) 的概率,抽到一张全新的牌,转移到 \(f_{i+1,j-1}\);有 \(\frac{i+j-n-m}{j}\) 的概率,抽到一张抽到过的牌,转移到 \(f_{i,j-1}\);有 \(\frac{m}{j}\) 的概率,抽到 joker,转移到 \(f_{i,n+m}\)。验证一下,\(\frac{n-i}{j}+\frac{i+j-n-m}{j}+\frac{m}{j}=1\),没有问题。

\[\Large f_{i,j}={\textstyle \frac{n-i}{j}}f_{i+1,j-1}+{\textstyle \frac{i+j-n-m}{j}}f_{i,j-1}+{\textstyle \frac{m}{j}}f_{i,n+m}+1 \]

边界 \(f_{n,j}=0\),答案 \(f_{0,n+m}\)。这不好递推,怎么办呢?

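我们可以把它看做二维平面内的随机游走:每一步要么走到 \((i+1,j-1)\),要么走到 \((i,j-1)\),要么跳回本行行末 \((i,n+m)\)。这种「跳回行末」的转移很经典:对第 \(i\) 行(第 \(i+1\) 行已经求出),设 \(f_{i,j}=k_{i,j}\cdot f_{i,n+m}+b_{i,j}\),从 \(j=n+m-i\) 一路推到 \(j=n+m\)。在 \(j=n+m\) 处得到一个只含未知数 \(f_{i,n+m}\) 的一元一次方程,解出 \(f_{i,n+m}\) 后回代,整行 \(f_{i}\) 就解出来了。具体可以见文末代码。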

上述 DP 的时空复杂度为 \(\Theta(n^2)\) 级别,需要优化。经过打表发现,对于固定的 \(i\),\(f_{i,j}\) 关于 \(j\) 是等差数列。

【404 not found】

作者太菜了,还不会证。

我们设 \(f_{i,j}=\lambda_i+\mu_i\cdot(n+m-j)\)。只需要知道任意两个 \(j\) 处的取值,就能确定 \(\lambda_i, \mu_i\),也就确定了整行 \(f_{i}\)。为了方便起见,取 \(j=n+m\) 和 \(j=n+m-1\) 这末两项,代入转移方程解出 \(\lambda_i,\mu_i\)(其中 \(\lambda_{i+1},\mu_{i+1}\) 已知)。

\[\begin{aligned} &\Large\left\{\begin{aligned} & f_{i,n+m} = {\textstyle \frac{n-i}{n+m}}f_{i+1,n+m-1}+{\textstyle\frac{i}{n+m}}f_{i,n+m-1}+{\textstyle\frac{m}{n+m}}f_{i,n+m}+1 \\ & f_{i,n+m-1}= {\textstyle \frac{n-i}{n+m-1}}f_{i+1,n+m-2}+{\textstyle\frac{i-1}{n+m-1}}f_{i,n+m-2}+{\textstyle\frac{m}{n+m-1}}f_{i,n+m}+1 \end{aligned}\right. \\\\ \Large\Rightarrow&\Large\left\{\begin{aligned} & \lambda_i = {\textstyle \frac{n-i}{n+m}}(\lambda_{i+1}+\mu_{i+1})+{\textstyle\frac{i}{n+m}}(\lambda_i+\mu_i)+{\textstyle\frac{m}{n+m}}\lambda_i+1 \\ & \lambda_i+\mu_i= {\textstyle \frac{n-i}{n+m-1}}(\lambda_{i+1}+2\mu_{i+1})+{\textstyle\frac{i-1}{n+m-1}}(\lambda_i+2\mu_i)+{\textstyle\frac{m}{n+m-1}}\lambda_i+1 \end{aligned}\right. \\\\ \Large\Rightarrow&\Large\left\{\begin{aligned} & \lambda_i={\textstyle\frac{i}{n-i}}\mu_i+\lambda_{i+1}+\mu_{i+1}+{\textstyle\frac{n+m}{n-i}} \\ & {\normalsize(n+m-2i+1)}\mu_i={\normalsize(i-n)}\lambda_i+{\normalsize n+m-1}+{\normalsize(n-i)}(\lambda_{i+1}+2\mu_{i+1}) \end{aligned}\right. \\\\ \Large\Rightarrow&\Large\left\{\begin{aligned} & \lambda_i={\normalsize\frac{i\cdot\mu_i}{n-i}+\lambda_{i+1}+\mu_{i+1}+\frac{n+m}{n-i}} \\ & \mu_i={\normalsize\frac{(n-i)\cdot\mu_{i+1}-1}{n+m-i+1}} \end{aligned}\right. \\\\ \end{aligned} \]

于是可以在 \(\mathcal{O}(n \log M)\) 的时间内求解(每一步都需要求逆元)。若 \(m = \mathcal{O}(n)\),则可以线性预处理 \(1 \sim n+m+1\) 的逆元,做到完全线性的 \(\mathcal{O}(n)\)。边界为 \(\lambda_n=\mu_n=0\),答案为 \(f_{0,n+m}=\lambda_0\)。

代码

取模板子
namespace Mod_Int_Class {
    /// Returns true iff `val` is representable in type T.
    /// Mixed signed/unsigned pairs are compared by sign first; a plain `<=`
    /// would convert numeric_limits<T>::min() of a signed T to a huge unsigned
    /// value, so e.g. in_range<int>(998244353u) would wrongly be false.
    template <typename T, typename _Tp>
    constexpr bool in_range(_Tp val) {
        if constexpr (!std::is_integral<T>::value || !std::is_integral<_Tp>::value
                      || std::is_signed<T>::value == std::is_signed<_Tp>::value) {
            return std::numeric_limits<T>::min() <= val && val <= std::numeric_limits<T>::max();
        } else if constexpr (std::is_signed<_Tp>::value) {
            // signed value, unsigned target: negative values never fit
            return val >= 0 && static_cast<std::make_unsigned_t<_Tp>>(val) <= std::numeric_limits<T>::max();
        } else {
            // unsigned value, signed target: only the upper bound can fail
            return val <= static_cast<std::make_unsigned_t<T>>(std::numeric_limits<T>::max());
        }
    }
    
    /// Trial-division primality test, usable in constant expressions.
    /// The loop test is `i <= val / i` instead of `i * i <= val` so that it
    /// cannot overflow when `val` is close to the maximum of _Tp.
    template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
    static constexpr inline bool is_prime(_Tp val) {
        if (val < 2) return false;
        for (_Tp i = 2; i <= val / i; ++i)
            if (val % i == 0)
                return false;
        return true;
    }
    
    /// Integer modulo `_mod`, stored as a canonical residue in [0, mod).
    /// @tparam _mod  the modulus (default 19260817); must be representable in T.
    /// @tparam T     storage type of the residue.
    /// @tparam S     wider type used for products and for reducing signed inputs.
    template <auto _mod = 19260817, typename T = int, typename S = long long>
    class Mod_Int {
        static_assert(in_range<T>(_mod), "mod must in the range of type T.");
        static_assert(std::is_integral<T>::value, "type T must be an integer.");
        static_assert(std::is_integral<S>::value, "type S must be an integer.");
        public:
            /// A default-constructed value is 0 (see the initializer of `val`).
            constexpr Mod_Int() noexcept = default;
            
            /// Converts any integer to its residue modulo `mod`; negative values
            /// are mapped into [0, mod) as well.
            template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
            constexpr Mod_Int(_Tp v) noexcept: val(0) {
                if constexpr (std::is_unsigned<_Tp>::value) {
                    // Reduce in an unsigned type: casting a value above the range
                    // of S (e.g. 2^64 - 1 to long long) would wrap it to a negative
                    // number and produce the wrong residue.
                    using U = std::common_type_t<_Tp, unsigned long long>;
                    val = static_cast<T>(static_cast<U>(v) % static_cast<U>(mod));
                } else if (0 <= S(v) && S(v) < mod) {
                    val = v;
                } else {
                    val = (S(v) % mod + mod) % mod;
                }
            }
            
            /// The canonical residue in [0, mod).
            constexpr T const& raw() const {
                return this -> val;
            }
            static constexpr T mod = _mod;
            
            template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
            constexpr friend Mod_Int pow(Mod_Int a, _Tp p) {
                return a ^ p;
            }
            constexpr friend Mod_Int sub(Mod_Int a, Mod_Int b) {
                return a - b;
            }
            constexpr friend Mod_Int& tosub(Mod_Int& a, Mod_Int b) {
                return a -= b;
            }
            
            /// Sum / product of one or more operands. Trailing operands may be
            /// plain integers: they are converted to Mod_Int before recursing,
            /// otherwise the recursive call would carry no Mod_Int argument and
            /// argument-dependent lookup could not find these friends.
            constexpr friend Mod_Int add(Mod_Int a) { return a; }
            template <typename... args_t>
            constexpr friend Mod_Int add(Mod_Int a, args_t... args) {
                return a + add(Mod_Int(args)...);
            }
            constexpr friend Mod_Int mul(Mod_Int a) { return a; }
            template <typename... args_t>
            constexpr friend Mod_Int mul(Mod_Int a, args_t... args) {
                return a * mul(Mod_Int(args)...);
            }
            template <typename... args_t>
            constexpr friend Mod_Int& toadd(Mod_Int& a, args_t... b) {
                return a = add(a, b...);
            }
            template <typename... args_t>
            constexpr friend Mod_Int& tomul(Mod_Int& a, args_t... b) {
                return a = mul(a, b...);
            }
            
            /// Modular inverse by Fermat's little theorem; only available when
            /// `mod` is prime. Precondition: a != 0.
            template <T check_mod = mod, typename = std::enable_if_t<is_prime(check_mod)>>
            static constexpr inline T inv(T a) {
                assert(a != 0);
                return _pow(a, mod - 2);
            }
            
            /// Unary plus returns a copy (a non-const reference cannot bind to
            /// `*this` inside a const member function).
            constexpr Mod_Int operator + () const {
                return *this;
            }
            constexpr Mod_Int operator - () const {
                return _sub(0, val);
            }
            constexpr Mod_Int inv() const {
                return inv(val);
            }
            
            constexpr friend inline Mod_Int operator + (Mod_Int a, Mod_Int b) {
                return _add(a.val, b.val);
            }
            constexpr friend inline Mod_Int operator - (Mod_Int a, Mod_Int b) {
                return _sub(a.val, b.val);
            }
            constexpr friend inline Mod_Int operator * (Mod_Int a, Mod_Int b) {
                return _mul(a.val, b.val);
            }
            constexpr friend inline Mod_Int operator / (Mod_Int a, Mod_Int b) {
                return _mul(a.val, inv(b.val));
            }
            /// Power a^p; the exponent must be non-negative.
            template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
            constexpr friend inline Mod_Int operator ^ (Mod_Int a, _Tp p) {
                return _pow(a.val, p);
            }
            
            constexpr friend inline Mod_Int& operator += (Mod_Int& a, Mod_Int b) {
                return a = _add(a.val, b.val);
            }
            constexpr friend inline Mod_Int& operator -= (Mod_Int& a, Mod_Int b) {
                return a = _sub(a.val, b.val);
            }
            constexpr friend inline Mod_Int& operator *= (Mod_Int& a, Mod_Int b) {
                return a = _mul(a.val, b.val);
            }
            constexpr friend inline Mod_Int& operator /= (Mod_Int& a, Mod_Int b) {
                return a = _mul(a.val, inv(b.val));
            }
            template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
            constexpr friend inline Mod_Int& operator ^= (Mod_Int& a, _Tp p) {
                return a = _pow(a.val, p);
            }
            
            constexpr friend inline bool operator == (Mod_Int a, Mod_Int b) {
                return a.val == b.val;
            }
            constexpr friend inline bool operator != (Mod_Int a, Mod_Int b) {
                return a.val != b.val;
            }
            
            constexpr Mod_Int& operator ++ () {
                this -> val + 1 == mod ? this -> val = 0 : ++this -> val;
                return *this;
            }
            constexpr Mod_Int& operator -- () {
                this -> val == 0 ? this -> val = mod - 1 : --this -> val;
                return *this;
            }
            constexpr Mod_Int operator ++ (int) {
                Mod_Int res = *this;
                this -> val + 1 == mod ? this -> val = 0 : ++this -> val;
                return res;
            }
            constexpr Mod_Int operator -- (int) {
                Mod_Int res = *this;
                this -> val == 0 ? this -> val = mod - 1 : --this -> val;
                return res;
            }
            
            // The stream operators use the injected class name: spelling the type
            // as Mod_Int<mod, T, S> names a different specialization whenever the
            // type of _mod differs from T (e.g. Mod_Int<998244353u>).
            friend std::istream& operator >> (std::istream& is, Mod_Int& x) {
                T ipt{};
                if (is >> ipt) x = ipt;  // leave x untouched if the read fails
                return is;
            }
            friend std::ostream& operator << (std::ostream& os, Mod_Int x) {
                return os << x.val;
            }
        protected:
            T val = 0;  // canonical residue in [0, mod)
            
            static constexpr inline T _add(T a, T b) {
                // (a + b) mod `mod` for a, b in [0, mod). The wrap branch uses
                // a - (mod - b) so that a + b is never formed when it could
                // exceed the range of T (moduli above half of T's maximum).
                return a >= mod - b ? a - (mod - b) : a + b;
            }
            static constexpr inline T _sub(T a, T b) {
                return a < b ? a - b + mod : a - b;
            }
            static constexpr inline T _mul(T a, T b) {
                return static_cast<S>(a) * b % mod;
            }
            
            /// Binary exponentiation. A negative exponent would never reach 0
            /// under `p >>= 1` (infinite loop), so it is rejected up front.
            template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
            static constexpr inline T _pow(T a, _Tp p) {
                if constexpr (std::is_signed<_Tp>::value)
                    assert(p >= 0);
                T res = 1;
                for (; p; p >>= 1, a = _mul(a, a))
                    if (p & 1) res = _mul(res, a);
                return res;
            }
    };
    
    using mint = Mod_Int<>;
    using mod_t = mint;
    
    constexpr mint operator ""_m (unsigned long long x) {
        return mint(x);
    }
    constexpr mint operator ""_mod (unsigned long long x) {
        return mint(x);
    }
}

using namespace Mod_Int_Class;
$\mathcal{O}(n^2)$ 部分分 & 打表
#include <cstdio>
#include <iostream>
#include <limits>
#include <cassert>
#include <vector>
using namespace std;

int n;  // number of normal cards
int m;  // number of jokers

namespace $1 {
    bool check() {
        return n <= 1000;
    }
    
    void solve() {
        vector<vector<mint>> f(n + 1, vector<mint>(n + 1));
        for (int i = n - 1; i >= 0; --i) {
            vector<mint> k(i + 1), b(i + 1);
            k[0] = 1_mod * m / (n + m - i);
            b[0] = 1_mod * (n - i) / (n + m - i) * f[i + 1][0] + 1;
            for (int j = 1; j <= i; ++j) {
                k[j] = 1_mod * m / (n + m - i + j)
                     + 1_mod * j / (n + m - i + j) * k[j - 1];
                b[j] = 1_mod * (n - i) / (n + m - i + j) * f[i + 1][j]
                     + 1_mod * j / (n + m - i + j) * b[j - 1] + 1;
            }
            f[i][i] = b[i] / (1 - k[i]);
            for (int j = 0; j < i; ++j)
                f[i][j] = k[j] * f[i][i] + b[j];
        }
        printf("%d\n", f[0][0].raw());
        
        vector<vector<double>> g(n + 1, vector<double>(n + 1));
        for (int i = n - 1; i >= 0; --i) {
            vector<double> k(i + 1), b(i + 1);
            k[0] = 1. * m / (n + m - i);
            b[0] = 1. * (n - i) / (n + m - i) * g[i + 1][0] + 1;
            for (int j = 1; j <= i; ++j) {
                k[j] = 1. * m / (n + m - i + j)
                     + 1. * j / (n + m - i + j) * k[j - 1];
                b[j] = 1. * (n - i) / (n + m - i + j) * g[i + 1][j]
                     + 1. * j / (n + m - i + j) * b[j - 1] + 1;
            }
            g[i][i] = b[i] / (1 - k[i]);
            for (int j = 0; j < i; ++j)
                g[i][j] = k[j] * g[i][i] + b[j];
            
            for (int j = 0; j <= i; ++j)
                printf("%.10lf ", g[i][j]);
            puts("");
            // for (int j = 1; j <= i; ++j)
            //     printf("%.10lf ", g[i][j] - g[i][j - 1]);
            // puts("");
        }
    }
}

// Forward declaration so that this snippet compiles even when the
// O(n log M) solution (namespace $yzh, shown separately) is placed after main.
namespace $yzh { void solve(); }

signed main() {
    #ifndef XuYueming
    freopen("toad.in", "r", stdin);
    freopen("toad.out", "w", stdout);
    #endif
    scanf("%d%d", &n, &m);
    // Small n: brute-force DP; otherwise the arithmetic-progression solution.
    if ($1::check()) return $1::solve(), 0;
    $yzh::solve();
    return 0;
}
$\mathcal{O}(n \log M)$ 正解
namespace $yzh {
    const int N = 1000010;
    
    mint lambda[N], mu[N];
    
    void solve() {
        lambda[n] = mu[n] = 0;
        for (int i = n - 1; i >= 0; --i) {
            mu[i] = ((n - i) * mu[i + 1] - 1) / (n + m - i + 1);
            lambda[i] = i * mu[i] / (n - i) + lambda[i + 1] + mu[i + 1] + 1_mod * (n + m) / (n - i);
        }
        printf("%d", lambda[0].raw());
    }
}
卡常后
#pragma GCC optimize("Ofast", "inline", "fast-math", "unroll-loops")
#include <cstdio>

constexpr int N = 1000010;
constexpr int mod = 19260817;

int n, m;         // numbers of normal cards and jokers
int lambda, mu;   // running lambda / mu of the recurrence
int Inv[N << 1];  // table of modular inverses, filled in main

// (a + b) mod `mod` for a, b already reduced into [0, mod).
inline int add(int a, int b) {
    const int room = mod - b;            // how far b sits below mod
    return a < room ? a + b : a - room;  // subtract mod at most once
}

// Linear-time version of the lambda/mu recurrence. The inverses of
// 1..n+m+1 are precomputed with inv(i) = -(mod / i) * inv(mod % i), which is
// valid because mod = 19260817 is prime. Loop step i computes
// mu_{n-i}, lambda_{n-i} from mu_{n-i+1}, lambda_{n-i+1}.
signed main() {
    freopen("toad.in", "r", stdin);
    freopen("toad.out", "w", stdout);
    scanf("%d%d", &n, &m), Inv[1] = 1;
    // NOTE(review): Inv holds N << 1 entries, so this assumes n + m + 1 < 2 * N.
    for (int i = 2; i <= n + m + 1; ++i)
        Inv[i] = 1ll * (mod - Inv[mod % i]) * (mod / i) % mod;
    for (int i = 1; i <= n; ++i) {
        const int prev_mu = mu;  // mu_{n-i+1}, still needed for lambda
        mu = 1ll * add(1ll * i * mu % mod, mod - 1) * Inv[m + i + 1] % mod;
        lambda = add(1ll * add(n + m, 1ll * (n - i) * mu % mod) * Inv[i] % mod, add(lambda, prev_mu));
    }
    printf("%d", lambda);
    return 0;
}
posted @ 2024-11-14 11:06  XuYueming  阅读(3)  评论(0编辑  收藏  举报