【题解】P3600 随机数生成器(概率,期望)

【题解】P3600 随机数生成器

期望好题!

这道题从 11 号下午开始在机房写的,从十一号晚上调到十三号凌晨,终于调过去了!

第四道黑!(激动)

这道题让我自己是无论如何都想不出来的。这里完整详细叙述一遍思路,帮我自己梳理一遍。


更新日志:
upd on 2022.5.17:增加了易错点部分详解和加入了少量代码注释。


题目链接

P3600 随机数生成器

思路分析

从问题入手,问题是求这 \(q\) 个区间中每个区间最小值的最大值的期望。

根据期望的定义,那么有:

\[\frac {\sum \limits_{i=1}^x f_i\times i}{x^n} \]

其中 \(f_i\) 表示的是区间最大值恰好为 \(i\) 的方案数。

考虑对于每个区间来说,显然若这个区间 \(A\) 包含另一个区间 \(B\),那么这个区间 \(A\) 的最小值一定小于等于区间 \(B\)

由于我们要求的是区间最小值的最大值,那么可以看出,区间 \(A\) 对最后答案无影响。

那么我们就可以把这些区间 \(A\) 删掉。可以发现,如果把剩下的区间按左端点从小到大排序,那么右端点一定是严格单调递增的。

回到上面那个式子。

显然 \(x^n\) 可以直接快速幂求得,那么问题就转化为如何求 \(f_i\)

\(f_i\) 并不好求,根据常用套路,那么我们可以考虑将恰好转化为最多

定义 \(g_i\) 表示的是区间最大值 \(\le i\) 的方案数。

那么对于每一个 \(i\),利用类似前缀和差分的思想,有:

\[f_i=g_i-g_{i-1} \]

即:区间最大值恰好为 \(i\) 的方案数是区间最大值 \(\le i\) 的方案数减去区间最大值 \(\le i-1\) 的方案数。

考虑如何求 \(g_i\)

当区间最大值整体 \(\le i\),那么就说明:每个区间都至少有一个 \(\le i\) 的数。

那么有:

\[g_i=\sum \limits_{j=1}^n tp_j \times i^j \times (x-i)^{n-j} \]

\(tp_j\) 表示的是在 \(1-n\) 个区间里的所有数中,选出 \(j\)\(\le i\) 的点,每个区间至少包含一个点的方案数。

我们要计算 \(g_i\),可以枚举 \(j\),那么对于这 \(j\)\(\le i\) 的点,有 \(i^j\) 种情况;对于剩下 \(> i\)\(n-j\) 个点,有 \((x-i)^{n-j}\) 种情况;那么根据乘法原理,每个 \(j\)\(g_i\) 的贡献就是:

\[tp_j \times i^j \times (x-i)^{n-j} \]

最后累加求和即可。

现在问题转化成如何求 \(tp_i\)

考虑 DP。

定义 \(dp_{i,j}\) 表示的是前 \(i\) 个位置放了 \(j\) 个点,且第 \(i\) 个位置必须放点,覆盖所有的左端点 \(\le i\) 的区间的方案数。

我们定义 \(L_i\)\(R_i\) 分别表示覆盖位置 \(i\)最左最右区间编号。

特殊地,对于没有被任何区间覆盖的位置 \(i\)\(R_i\) 定义为右端点严格小于 \(i\) 的最后一个区间,\(L_i=R_i+1\)

那么有:

\[dp_{i,j}=\sum \limits_{R_k+1 \ge L_i}dp_{k,j} \]

枚举 \(i,j,k\),时间复杂度 \(O(n^3)\)

考虑优化。可以发现参与转移的 \(k\) 是一段连续的区间,那么前缀和优化即可。时间复杂度:\(O(n^2)\)

即:

\[dp_{i,j}=sum_{i-1,j-1}-sum_{k-1,j-1} \]

其中 \(sum_{i,j}\) 表示的是上一层的 \(\sum \limits_{k=1}^i dp_{k,j}\)

这里要注意的一点是:\(k-1\) 有可能为负,所以要特判 \(k=0\) 的情况。(我因为这个问题挂了一下午,而且本地测试挂了点还能过就很离谱

那么 \(g_j\) 就可以计算出:

\[g_j=\sum \limits_{R_i=q}dp_{i,j} \]

总的时间复杂度:\(O(n^2)\)

最后我们来梳理一下求解过程:

求期望 \(\rightarrow\) 求区间最大值恰好\(i\) 方案数 \(f_i \rightarrow\) 求区间最大值 \(\le i\) 的方案数 \(g_i \rightarrow\) 求每个区间都至少有一个 \(\le i\) 的点的方案数 \(tp_i \rightarrow\)\(dp_{i,j}\)

求解步骤

  1. 区间去重+排序:暴力枚举;

  2. 求出每个点 \(i\)\(R_i,L_i\):对于区间 \([1,x]\) 的每个点,暴力枚举每个区间计算,最后特判在最后一个区间右边的点;

  3. 枚举 \(i,j\),前缀和优化,计算出 \(dp_{i,j}\)

  4. 枚举 \(j\),根据 \(dp_{i,j}\),计算出 \(tp_j\)

  5. 枚举 \(i\),根据 \(tp_j\) 计算 \(g_i\)

  6. 根据 \(g_i\) 计算出 \(f_i\)

  7. 套用公式 \(\frac {\sum \limits_{i=1}^x f_i\times i}{x^n}\) 计算出最终答案。

易错点

这题调了不知道多久……提交了 15 次。。。

  • 关于区间去重+排序:

    1. 暴力枚举两个区间 \(i,j\) 是否包含时,要注意 \(i \ne j\)

    2. 当两个区间左右端点相同并且其中一个端点已经被标记为删除点时,不需要去重,即:if(a[i].l==a[j].l&&a[i].r==a[j].r&&(vis[i]||vis[j])) continue;

  • 关于求 \(R_i,L_i\)

    这里由于有很多不同的处理方法,所以先放出我的处理方法:

    void work()//求出 L[i] 和 R[i]
    {
        for(int i=1;i<=x;i++)
        {
            int j;
            for(j=1;j<=q;j++)
            {
                if(b[j].l<=i&&b[j].r>=i)
                {
                    if(!L[i])L[i]=j;
                    R[i]=j;
                }
                else if(b[j].l>i)
                {
                    if(!L[i])R[i]=j-1,L[i]=R[i]+1;
                    //若 i 不被任何区间所覆盖,则 R[i] 设为最后一个右端点比它小的区间编号,L[i]=R[i]+1; 
                    break;
                }
            }
            if(!L[i])R[i]=R[i-1],L[i]=R[i]+1;
        }
    } 
    

    我枚举 \(1-x\) 的每个点作为 \(i\),求出 \(L_i,R_i\)

    最易错的点就是,当一个点没有被任何区间覆盖时,有两种情况:

    一是在两个区间之间的“空隙”里,即:

    这种情况下,我们只需要在循环时特判一下:遇到一个左边界已经大于它的点 \(j\),那么 \(R_i=j-1,L_i=R_i+1\)

    二是在最后一个区间后面,即:

    这时不能仿照上面的处理方式了,我们需要在循环结束时判断一下这个点是否为最右端区间后面的点,具体做法就是判断 \(L_i\) 是否为 0(是否被更新过),若为 0,则说明这个点就是最右端区间后面的点,那么:R[i]=R[i-1],L[i]=R[i]+1;,这里要注意的一点是,这里的 \(R_i\) 赋值为 \(R_{i-1}\)我在这挂了至少半个小时。

  • 计算 \(dp_{i,j}\)

    因为 \(dp_{i,j}=\sum \limits_{R_k+1 \ge L_i}dp_{k,j}\),我们要使用前缀和优化就必须找到最小的 \(k\)

    所以我们通过 while(k<i-1&&R[k]+1<L[i])k++; 来找到 \(k\)。那么有一种可能 \(k=0\),那么在计算 \(sum_{k-1,j-1}\) 就会访问到负下标,会 RE(但是不知道为啥我的代码一直是 WA 并且本地编译毫无错误),所以需要特判一下,即:

    dp[i][j]=(sum[i-1][j-1]-(k?sum[k-1][j-1]:0)+mod)%mod;
    
  • 还是一样的套路,取模运算不要忘掉 ( +mod)%mod;

  • 整个 DP 部分用到很多 for 语句不要把边界弄混了(分清楚 \(x,n\),不要写反 \(i,j,k\)

实现代码

//luoguP3600
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<stack>
#define int long long
using namespace std;
const int mod=666623333;
const int maxn=2005;

struct Node{int l,r;}a[maxn],b[maxn];
int n,x,q;
int cnt,mx,ans,mi;
int vis[maxn],book[maxn],L[maxn],R[maxn]; 
//L[i] 和 R[i] 分别表示 i 可以覆盖到的编号最小和最大的区间 
int dp[maxn][maxn],tp[maxn],g[maxn],f[maxn],sum[maxn][maxn];
//dp[i][j] 表示前 i 个位置放了 j 个√,且第 i 个位置必须放√,覆盖所有的左端点 <=i 的区间的方案书。 
//tp[i] 表示的是在 [1,n] 内选出 i 个点,每个区间至少包含一个点的方案数。
//g[i] 表示的是区间最大值 <=i 时的方案数。
//f[i] 表示区间最大值恰好 =i 时的方案数。 

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
    return x*f;
}

int cmp(Node a,Node b){if(a.l==b.l)return a.r<b.r;return a.l<b.l;}

void init()//去重(将包含区间去除) 
{
    for(int i=1;i<=q;i++)
    {
        for(int j=1;j<=q;j++)
        {
            if(i==j)continue;//注意 i 不能等于 j。
            if(a[i].l==a[j].l&&a[i].r==a[j].r&&(vis[i]||vis[j])) continue; 
            if(a[i].l>=a[j].l&&a[i].r<=a[j].r)vis[j]++;
            else if(a[i].l<=a[j].l&&a[i].r>=a[j].r)vis[i]++;
        }
    }
    for(int i=1;i<=q;i++)
    {
        if(!vis[i])b[++cnt].l=a[i].l,b[cnt].r=a[i].r;
    }
    return ;
}

void work()//求出 L[i] 和 R[i]
{
    //cout<<mx<<endl;
    for(int i=1;i<=x;i++)
    {
        int j;
        for(j=1;j<=q;j++)
        {
            if(b[j].l<=i&&b[j].r>=i)
            {
                if(!L[i])L[i]=j;
                R[i]=j;
            }
            else if(b[j].l>i)
            {
                if(!L[i])R[i]=j-1,L[i]=R[i]+1;
                //若 i 不被任何区间所覆盖,则 R[i] 设为最后一个右端点比它小的区间编号,L[i]=R[i]+1; 
                break;
            }
        }
        if(!L[i])R[i]=R[i-1],L[i]=R[i]+1;//处理在最后一个区间后面的点的情况
    }
} 

int qpow(int a,int T)
{
    int ret=1;
    while(T)
    {
        if(T&1)(ret*=a)%=mod;
        (a*=a)%=mod;T>>=1;
    }
    return ret;
}

void out()
{
    for(int i=1;i<=x;i++)cout<<L[i]<<" "<<R[i]<<endl;
    /*for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=i;j++)cout<<dp[i][j]<<endl;
    }*/
    //for(int i=1;i<=n;i++)cout<<tp[i]<<endl;
}

void DP()
{
    //求 dp[i][j] 
    dp[0][0]=sum[0][0]=1;
    int k=0;
    for(int i=1;i<=n;i++)
    {
        while(k<i-1&&R[k]+1<L[i])k++;
        for(int j=1;j<=i;j++)
        {
            dp[i][j]=(sum[i-1][j-1]-(k?sum[k-1][j-1]:0)+mod)%mod;//防止负下标
        }
        for(int j=0;j<=i;j++)sum[i][j]=(sum[i-1][j]+dp[i][j])%mod;
    }
    //out(); 
    //求 tp[j] 
    for(int i=1;i<=n;i++)
    {
        if(R[i]!=cnt)continue;//注意这里是 R[i]!=cnt 而非 R[i]!=n
        for(int j=1;j<=n;j++)
        {
            (tp[j]+=dp[i][j])%=mod;
         } 
    }
    //out();
    //求 g[i] 
    for(int i=1;i<=x;i++)
    {
        for(int j=1;j<=n;j++)
        {
            (g[i]+=tp[j]*qpow(i,j)%mod*qpow(x-i,n-j)%mod)%=mod;
        }
    }
    //out();
    //求 f[i] 
    for(int i=1;i<=x;i++)f[i]=(g[i]-g[i-1]+mod)%mod;
    //out();
    int e=qpow(qpow(x,n),mod-2);//最后答案要除以 x^n 
    for(int i=1;i<=x;i++)(ans+=(i*f[i]%mod*e%mod))%=mod;
}

signed main()
{
    n=read();x=read();q=read();
    for(int i=1;i<=q;i++)
    {
        a[i].l=read();a[i].r=read();
    }
    sort(a+1,a+q+1,cmp);
    init();
    sort(b+1,b+cnt+1,cmp);//按 l 排序,保证 r 严格递增。
    work(); 
    DP();
    cout<<ans<<endl;
    return 0;
}
posted @ 2022-05-14 15:47  向日葵Reta  阅读(90)  评论(0编辑  收藏  举报