CF1366E 两个数组

1 CF1366E 两个数组

2 题目描述

时间限制 \(2s\) | 空间限制 \(256M\)

给出两个数组 \(𝑎_1,𝑎_2,…,𝑎_𝑛\)\(𝑏_1,𝑏_2,…,𝑏_𝑚\)。数组 \(b\) 是按照升序排好序。你需要把数组 \(a\) 分成 \(m\) 个连续子序列,使得对于每个 \(i\)\(1\)\(m\),第 \(i\) 个字串的最小值和 \(b_i\) 相等。注意每个元素严格的属于某个子序列,他们是这么生成的:\(a\) 的前几个元素构成第一个子序列,后面跟着的几个元素生成第二个子序列,以此类推。例如,如果 \(𝑎=[12,10,20,20,25,30]\), \(𝑏=[10,20,30]\),那么 \(a\) 的两个好的分割如下:

  • \([12,10,20],[20,25],[30]\)
  • \([12,10],[20,20,25],[30]\)

计算字符串 \(a\) 分割的方法的个数,结果对 \(998244353\) 取模。

数据范围:\(1≤𝑛,𝑚≤2\times 10^5\), \(1≤𝑎_𝑖,b_i≤10^9\), \(𝑏_𝑖<\)𝑏\(_𝑖\)\(_+\)\(_1\)

3 题解

我们先来思考答案为 \(0\) 的情况。如果 \(n < m\),那我们必然无法在 \(a\) 中找到最小值等于 \(b\) 中元素的 $m $ 段区间。如果我们在 \(b_1\) 这个值在 \(a\) 数组出现之前出现了一个比 \(b_1\) 还小的数,那么我们无论如何也得把这个数并到第一个区间内,而第一个区间的最小值也就无论如何都不是 \(b_1\) 了。如果在 \(b_i\)\(b_{i+1}\) 之间存在一个比 \(b_i\) 小的数,那么我们必定无法将第 \(i\) 个区间和第 \(i+1\) 个区间的最小值都处理好:\(b_i < b_{i+1}\),且必定有一个区间包含这个较小的数。同时,如果在 \(b\) 中的某个数并不在 \(a\) 中出现过,我们肯定无法将某个区间的最小值变为这个 \(b\) 数组中的数。上述的这些情况答案都为 \(0\)

我们考虑 \(b_i\)\(b_{i+1}\) 之间的可能方案。我们可以通过将 \(b_{i+1}\) 或者在 \(b_i\)\(b_{i+1}\) 中出现过的与 \(b_{i+1}\) 相邻的大于 \(b_{i+1}\) 的数分给第 \(i+1\) 个区间或者第 \(i\) 个区间来增加方案数。我们最后将得到的 \(b_i\)\(b_{i+1}\) 之间的可能方案数乘上之前所有可能方案数之积,得到的就是到前 \(i+1\) 个数中可能的总方案数。

4 代码(空格警告):

#include <iostream>
using namespace std;
const int N = 2e5+10;
const int mod = 998244353;
#define int long long
int n, m, pos, ans, cnt, cnt2, flag;
int a[N], b[N], s[N], len[N];
signed main()
{
    cin >> n >> m;
    if (n < m)
    {
        cout << 0;
        return 0;
    }
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1; i <= m; i++) cin >> b[i];
    for (int i = 1; i <= n; i++)
    {
        if (a[i] < b[1])
        {
            cout << 0;
            return 0;
        }
        if (a[i] == b[1])
        {
            pos = i;
            break;
        }
    }
    cnt = 2;
    if (pos == 0)
    {
        cout << 0;
        return 0;
    }
    for (int i = pos+1; i <= n; i++)
    {
        if (a[i] == b[cnt])
        {
            if (a[i+1] != b[cnt]) cnt++;
        }
        else
        {
            if (a[i] < b[cnt-1])
            {
                cout << 0;
                return 0;
            }
        }
    }
    if (cnt != m+1)
    {
        cout << 0;
        return 0;
    }
    ans = 1;
    cnt = 1;
    for (int i = 1; i <= n; i++)
    {
        if (a[i] == b[cnt] && !flag)
        {
            len[cnt]++;
            flag = 1;
            s[cnt] = i;
        }
        else if (a[i] == b[cnt])
        {
            len[cnt]++;
        }
        else
        {
            flag = 0;
            if (a[i] == b[cnt+1])
            {
                cnt++;
                len[cnt]++;
                flag = 1;
                s[cnt] = i;
            }
        }
    }
    cnt = 2;
    cnt2 = 0;
    for (int i = pos+1; i <= n; i++)
    {
        if (i == s[cnt])
        {
            ans = (ans * (cnt2 + len[cnt]) % mod) % mod;
            i += len[cnt] - 1;
            cnt2 = 0;
            cnt++;
        }
        else
        {
            if (a[i] > b[cnt])
            {
                cnt2++;
            }
            if (a[i] < b[cnt])
            {
                cnt2 = 0;
            }
        }
    }
    cout << ans;
    return 0;
}

欢迎关注我的公众号

posted @ 2021-03-14 15:53  David24  阅读(74)  评论(0编辑  收藏  举报