Codeforces Round #778 (Div. 1 + Div. 2)

不会真的有人 \(\rm D\) 罚坐 \(\rm 1h\) 想不到分解质因数吧。话说 CF 题解怎么又有锅(

A. Maximum Cake Tastiness

\(t\) 组数据,每组给出一个长为 \(n\) 的序列 \(a\),可以进行以下操作最多一次:

  • 选择一段 \([l,r]\),并反转 \(a_{l..r}\)

最大化操作后的序列相邻两数和的最大值。(\(1\le t\le 50,2\le n\le 10^3,1\le a_i\le 10^9\))

注意到,通过这次操作,我们能让任意两个我们想相邻的数相邻。具体来讲,如果想让 \(a_{k_1},a_{k_2}(k_1<k_2)\) 相邻,只需反转 \([k_1+1,k_2]\)。所以原问题变为原序列任意两个数和的最大值,std::sort 一下找最大的俩即可。时间复杂度 \(\mathcal{O}(tn\log n)\)

#include <cstdio>
#include <algorithm>
const int N = 1e5 + 10; int n, m, a[N];
int main()
{
    int qwq; scanf("%d", &qwq);
    while (qwq--)
    {
        scanf("%d", &n);
        for (int i = 1; i <= n; ++i) scanf("%d", a + i);
        std::sort(a + 1, a + n + 1);
        printf("%d\n", a[n] + a[n - 1]);
    }
    return 0;
}

B. Prefix Removals

\(t\) 组数据,每组给出一个字符串 \(s\),不断重复以下操作直到无法执行:

  • 定义 \(x\) 为其最长的,满足在 \(s\) 中出现了至少两次的前缀,从 \(s\) 中删去 \(x\)

求操作结束后的 \(s\)。(\(1\le t\le 10^4,1\le \sum |s|\le 2\times 10^5\))

注意到,如果不考虑 \(x\) 是最长的限制,答案是不会变的。所以我们的问题变为了,能不能删去首字符,一个字符能删去,当且仅当它出现了至少两次。直接开个桶动态记录每个字符出现了几次即可,时间复杂度 \(\mathcal{O}(t|s|)\)

#include <cstdio>
#include <cstring>
const int N = 3e5 + 10; int n, m, bkt[26]; char s[N];
inline void clr() { memset(bkt, 0, sizeof (bkt)); }
int main()
{
    int qwq; scanf("%d", &qwq);
    while (qwq--)
    {
        clr(); scanf("%s", s + 1); n = strlen(s + 1);
        for (int i = 1; i <= n; ++i) ++bkt[s[i] - 'a'];
        int p = 1;
        while (bkt[s[p] - 'a'] >= 2) --bkt[s[p] - 'a'], ++p;
        printf("%s\n", s + p);
    }
    return 0;
}

C. Alice and the Cake

\(t\) 组数据,每组有一个数 \(X\)(不给出 \(X\)),现在进行以下操作恰好 \(n-1\) 次:

  • 任意选择一个数 \(w(w\ge 2)\),将它分割为 \(\lfloor\frac{w}{2}\rfloor,\lceil\frac{w}{2}\rceil\)

乱序给出最终长为 \(n\) 的序列 \(a\),判断是否存在一个操作序列,满足能从 \(X\) 生成 \(a\)。(\(1\le t\le 10^4,1\le \sum n\le 2\times10^5,n\ge 1,1\le a_i\le 10^9\))

注意到 \(X\) 不给出我们也能求出来,它就等于 \(\sum a_i\)。知道 \(X\) 后,其实我们就能知道把 \(X\) 分割的整个过程了,其中最终序列一定是这个过程中的某个状态。

考虑用一个队列,表示当前需要的数,和一个桶,表示每个数还剩多少,初始时队列里只有 \(X\)。每次取出队头,如果这个数在最终序列里出现了,说明它已经被表示出来了,没有分割,把这个数从桶中删去一个即可。而如果这个数没有出现,说明它被分割了,将它分割后的值加入队列即可。注意到这个过程将会进行 \(n-1\) 次,如果 \(n-1\) 次后,队列和桶都是空的,说明东西恰好用完,最终序列合法。反之不合法。时间复杂度 \(\mathcal{O}(n\log n)\)

#include <map>
#include <queue>
#include <cstdio>
#include <cstring>
#include <algorithm>
const int N = 3e5 + 10; typedef long long ll; int n, m, a[N];
std::map<ll, int> bkt;
inline void clr() { bkt.clear(); }
int main()
{
    int qwq; scanf("%d", &qwq);
    while (qwq--)
    {
        clr(); scanf("%d", &n); ll s = 0;
        for (int i = 1; i <= n; ++i) scanf("%d", a + i), ++bkt[a[i]], s += a[i];
        if (n == 1) { puts("YES"); continue; }
        std::queue<ll> q; q.push(s);
        for (int i = 1; i < n; ++i)
        {
            ll u = q.front(); q.pop();
            ll x = u / 2, y = u - x;
            if (!bkt.count(x)) q.push(x);
            else if (--bkt[x] == 0) bkt.erase(x);
            if (!bkt.count(y)) q.push(y);
            else if (--bkt[y] == 0) bkt.erase(y);
        }
        puts(bkt.empty() && q.empty() ? "YES" : "NO");
    }
    return 0;
}

D. Potion Brewing Class

\(t\) 组数据,每组有一个长为 \(n\) 的序列 \(r\),给出 \(n-1\) 个形如 \((x,y,a,b)\) 的比例关系,表示:

\[r_x:r_y=a:b \]

求满足以上条件的 整数 序列中,\(\sum r_i\) 最小的序列的和,答案对 \(998,244,353\) 取模。(\(1\le t\le 10^4,1\le \sum n\le 2\times 10^5,n\ge 2,1\le x,y,a,b\le n\),保证给出的关系足以确定 \(r_i\))

本来以为这题巨好写就没管 \(\rm AB\) 的罚时。然后发现自己小学奥数不会了。

考虑这个关系足以确定 \(r_i\),所以如果把 \((x,y)\) 看成边,原图是连通的,由此,原关系形成一棵树。而对于这棵树,一旦根(显然原题是无根树,不过我们可以随便钦定一个)的值确定为正确的值,所有的点的值都确定了,考虑从上到下确定即可。既然这样,我们可以假设根的值为 \(x\),然后往下确定。这样我们会得到某个点的值形如:

\[\dfrac{px}{q} \]

不妨设:

\[p_i=\prod_{i}\mathbf{p_i}^{t_i},q_i=\prod_{i}\mathbf{p_i}^{c_i},\mathbf{p_i}\in\mathbb{P} \]

\(x\) 应该至少为:

\[\prod_{i}\mathbf{p_i}^{\max(c_i-t_i,0)} \]

\(x\) 的最小值也应为这个值,所以我们现在的任务就变为了对于每个质数求出 \(\max(c_i-t_i,0)\)

首先对值域内的数分解质因数,可以用埃筛或者欧拉筛,均在 \(\mathcal{O}(n\log n)\) 的时间复杂度内求解。然后可以做到一遍 \(\rm dfs\) 求出每个 \(\mathbf{p_i}\)\(\max(c_i-t_i,0)\)。注意到每个值只会对 \(\mathcal{O}(\log n)\)\(\mathbf{p_i}\) 有影响,所以复杂度是对的。而且我们可以用离开该结点时撤销该结点的影响的方法,来把空间从 \(\mathcal{O}(\frac{n^2}{\log n})\),压到 \(\mathcal{O}(\frac{n}{\log n})\)

然后就能得到 \(x\) 的值。之后再做一遍 \(\rm dfs\) 就能得到每个点的值。时间复杂度 \(\mathcal{O}(n\log n)\)

#include <cstdio>
#include <vector>
const int N = 2e5 + 10, mod = 998244353; typedef long long ll;
struct node
{ 
    int x, y; 
    node(int x = 0, int y = 0) : x(x), y(y) { }
}; std::vector<std::pair<int, node>> G[N];
std::vector<int> V[N]; int p[N], vis[N], f[N], g[N], val[N], tp;
inline int ksm(int a, int b)
{
    int ret = 1;
    while (b)
    {
        if (b & 1) ret = (ll)ret * a % mod;
        a = (ll)a * a % mod; b >>= 1;
    }
    return ret;
}
inline void insert(int x, int y, int a, int b) 
{ 
    G[x].emplace_back(y, node(a, b));
    G[y].emplace_back(x, node(b, a));
}
inline void sieve(int n)
{
    for (int i = 2; i <= n; ++i)
    {
        if (!vis[i]) V[i].push_back(i), p[++tp] = i;
        for (int j = 1; j <= tp && i * p[j] <= n; ++j)
        {
            V[i * p[j]] = V[i]; V[i * p[j]].push_back(p[j]);
            vis[i * p[j]] = 1;
            if (i % p[j] == 0) break;
        }
    }
}
void dfs(int u, int fa)
{
    int v; node w;
    for (auto d : G[u])
    {
        v = d.first; w = d.second; if (v == fa) continue;
        for (auto p : V[w.y]) ++f[p];
        for (auto p : V[w.x]) --f[p], g[p] = std::min(g[p], f[p]);
        dfs(v, u);
        for (auto p : V[w.y]) --f[p];
        for (auto p : V[w.x]) ++f[p];
    }
}
void work(int u, int fa)
{
    int v, e, f; node w;
    for (auto d : G[u])
    {
        v = d.first; w = d.second; if (v == fa) continue;
        e = f = 1;
        for (auto p : V[w.y]) e = (ll)e * p % mod;
        for (auto p : V[w.x]) f = (ll)f * p % mod;
        val[v] = (ll)val[u] * e % mod * ksm(f, mod - 2) % mod;
        work(v, u);
    }
}
int main()
{
    int qwq; scanf("%d", &qwq); sieve(N - 1);
    while (qwq--)
    {
        int n; scanf("%d", &n);
        for (int i = 1, x, y, a, b; i < n; ++i)
            scanf("%d%d%d%d", &x, &y, &a, &b), insert(x, y, a, b);
        dfs(1, 0); val[1] = 1;
        for (int i = 1; i <= tp; ++i)
            if (g[p[i]] < 0) val[1] = (ll)val[1] * ksm(p[i], -g[p[i]]) % mod;
        work(1, 0); int ans = 0;
        for (int i = 1; i <= n; ++i) (ans += val[i]) %= mod;
        printf("%d\n", ans);
        for (int i = 1; i <= n; ++i) f[i] = g[i] = 0, G[i].clear();
    }
    return 0;
}

E. Arithmetic Operations

给出一个长为 \(n\) 的序列 \(a\),修改一些数使它能为等差数列。(修改后的数可以是任意整数)求最小的修改次数。(\(1\le n,a_i\le 10^5\))

\(m=\max a_i\)。考虑一旦公差确定,且某一项保留也确定了,那整个等差数列就确定了。考虑两个项能被同时保留,当且仅当它们对应的首项和公差相同。如果我们枚举后者,\(\mathcal{O}(1)\) 求前者,则能在 \(\mathcal{O}(nm)\) 的时间复杂度内求解。

但注意到,当公差变得比较大的时候,能保留的项也会变少,如果我们还按照朴素做法做的话,会浪费很多时间。所以我们有个初步的想法,公差小的时候和大的时候用两种算法。

考虑当公差为 \(w\) 时,两个下标 \(i,j(j\ge i)\) 就不能同时保留,当且仅当:

\[j-i\ge \dfrac{m}{w} \]

因为注意到,\(i,j\) 之间的距离要为 \(m\),而这已经是最大值了,所以不可能同时保留。考虑对于能同时保留的 \(i,j(j\ge i)\),从 \(i\)\(j\) 连一条边,边权为 \(\frac{a_j-a_i}{j-i}\),如果边权是实数则不连。则我们会发现,边数是 \(\mathcal{O}(\frac{mn}{w})\) 级别的。这样我们的问题就是,这张 \(\sf dag\) 上,相同边权的路径长度的最大值,这个可以用拓扑排序在 \(\mathcal{O}(\frac{mn}{w})\) 的时间复杂度内求解。

观察刚刚我们的讨论,朴素算法是 \(\mathcal{O}(mn)\) 的,公差为 \(w\) 的拓扑算法是 \(\mathcal{O}(\frac{mn}{w})\) 的。后者在公差小的时候会退化,而前者可以处理好公差小的情况。考虑根号分治,令 \(mn=\frac{mn}{w}\),则 \(w=\sqrt{m}\)。这样,我们把 \(<\sqrt{m}\) 的边权暴力枚举算,\(\ge \sqrt{m}\) 的边权用拓扑排序算。总时间复杂度 \(\mathcal{O}(n\sqrt{m})\)

#pragma GCC optimize("Ofast")
#include <cmath>
#include <queue>
#include <cstdio>
#include <algorithm>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
inline void read(int& x)
{
    x = 0; char ch; 
    while ((ch = getchar()) < '0' || ch > '9') ;
    while (x = (x << 1) + (x << 3) + ch - '0', (ch = getchar()) >= '0' && ch <= '9') ;
}
const int N = 1e5 + 10, B = 210; int a[N], n, m, ans = N;
namespace work1
{
    __gnu_pbds::gp_hash_table<int, int> mp;
    void main(int d)
    {
        int mx = 0; mp.clear();
        for (int i = 1, s = 0; i <= n; ++i, s += d) ++mp[a[i] - s];
        for (auto v : mp) mx = std::max(mx, v.second);
        ans = std::min(ans, n - mx);
    }
}
namespace work2
{
    std::vector<std::pair<int, int>> G[N]; int in[N];
    __gnu_pbds::gp_hash_table<int, int> f[N];
    void main() 
    {
        for (int i = 1; i <= n; ++i) G[i].clear(), f[i].clear();
        for (int i = 1; i <= n; ++i)
            for (int j = i + 1; j <= i + B && j <= n; ++j)
                if ((a[j] - a[i]) % (j - i) == 0 && a[j] >= a[i])
                    G[i].emplace_back(j, (a[j] - a[i]) / (j - i)), ++in[j];
        std::queue<int> q; int mx = 0;
        for (int i = 1; i <= n; ++i) if (!in[i]) q.push(i);
        while (!q.empty())
        {
            int u = q.front(), v, w; q.pop();
            for (auto d : G[u])
            {
                v = d.first; w = d.second;
                f[v][w] = std::max(f[v][w], f[u][w] + 1);
                if (w >= B) mx = std::max(mx, f[v][w]);
                if (!--in[v]) q.push(v);
            }
        }
        ans = std::min(ans, n - mx - 1);
    }
}
int main()
{
    read(n);
    for (int i = 1; i <= n; ++i) read(a[i]), m = std::max(m, a[i]);
    for (int d = 0; d < B; ++d) work1::main(d);
    work2::main(); std::reverse(a + 1, a + n + 1);
    for (int d = 0; d < B; ++d) work1::main(d);
    work2::main(); printf("%d\n", ans); return 0;
}

F. Minimal String Xoration

给出一个长为 \(2^n\) 的字符串 \(s\),下标从 \(0\) 开始。对于 \(j(0\le j<2^n)\),我们定义 \(t_j\),其中 \(t_{j,i}=s_{i\operatorname{xor} j}\)。求字典序最小的 \(t_j\)。(\(1\le n\le 18\))

我们发现,如果我们把 \(\operatorname{xor}\) 换为 \(+\),然后把超出范围的下标排除掉,则 \(t_j\) 其实就是后缀,原问题其实就是后缀排序。而把 \(+\) 换成 \(\operatorname{xor}\) 思路是类似的,我们考虑倍增。

考虑设 \(a_k\) 表示比较了前 \(2^k\) 个字符的大小关系,第 \(p\) 小的 \(t\)\(t_{a_{k,p}}\)。则我们现在的想法是通过 \(a_k\) 得到 \(a_{k+1}\)。考虑设 \(v\)\(v_i<v_j\),当且仅当 \(t_i<t_j\),这个可以根据 \(a_k\) 得出。然后我们类似后缀排序的思路,按照双关键字 \((v_i,v_{i\operatorname{xor}2^k})\) 进行排序,即可得到 \(a_{k+1}\)

\(a_0\) 可以按照朴素方法得到。按照上面的思路得到 \(a_{n}\) 后,答案即为 \(t_{a_{n,0}}\)。用 std::sort 可以做到 \(\mathcal{O}(2^nn^2)\) 的时间复杂度,如果改用基数排序,可以做到 \(\mathcal{O}(2^nn)\)。均可通过。

#include <cstdio>
#include <algorithm>
const int K = 19, N = 1 << K; char s[N]; int a[N], v[N], t[N];
int main()
{
    int k, n; scanf("%d%s", &k, s); n = 1 << k;
    for (int i = 0; i < n; ++i) a[i] = i;
    std::sort(a, a + n, [](const int& i, const int& j) { return s[i] < s[j]; });
    for (int i = 1; i < n; ++i) v[a[i]] = v[a[i - 1]] + (s[a[i]] != s[a[i - 1]]);
    for (int l = 1; l < n; l <<= 1)
    {
        auto cmp = [&](const int& i, const int& j)
        { return v[i] == v[j] ? (v[i ^ l] < v[j ^ l]) : v[i] < v[j]; };
        std::sort(a, a + n, cmp);
        for (int i = 1; i < n; ++i) t[a[i]] = t[a[i - 1]] + cmp(a[i - 1], a[i]);
        for (int i = 0; i < n; ++i) v[i] = t[i];
    }
    for (int i = 0; i < n; ++i) putchar(s[i ^ a[0]]);
    return 0;
}

G. Snowy Mountain

给出一棵 \(n\) 个结点的树,有一些结点是关键点,关键点高度是 \(0\),定义一个非关键点的高度 \(h_u\) 是离最近的关键点的距离。从一个结点出发,只能走到高度小于等于它的结点。从 \(i\) 走到 \(j\),如果 \(h_i=h_j\),则消耗 \(1\) 体力,如果 \(h_i>h_j\),则恢复 \(1\) 体力。对于每个结点,求从该结点出发能走到的结点数量最大值。(\(2\le n\le 2\times10^5\))

考虑从一个结点出发的最优策略是什么,注意到因为走的是一棵树,且只能向下走,所以最优策略是:

\(u\) 出发,走到一个尽量低的非关键结点 \(x\),满足 \(x\) 有一个相邻的结点 \(y\),且它们高度相同。在 \(x,y\) 之间来回走动消耗完全部体力。然后再走到关键点。

这样能在必然要走的 \(h_u\) 个结点基础上,多走 \(h_u-h_x\) 个结点,共 \(2h_u-h_x\) 个结点。则我们现在的问题变为了,对于所有结点,找到最小的 \(h_x\) 满足 \(u\) 能到达 \(x\)。(\(x\) 满足刚刚最优策略里的条件)如果不存在这样的 \(x\),则就只能找 \(u\)

但现在这道题还是不太好做,因为要对每个结点分别求解。不过我们发现,似乎满足最优策略条件的点不会太多?考虑设 \(S\) 表示满足条件的点集合。这里写的 CF 老哥的证明。考虑把选择某个关键点作为原树的根,则满足 \(v\in S\) 的点 \(v\) 一定在根与某个关键点路径的中间边的两端,选择深度较深的那个,与那个关键点连边,边的长度即为高度。之后,我们删去其他关键点和根的路径,因为这些路径上不可能再有关键点了。这样,原树被分为了若干类似的森林。重复上述过程,发现我们连出的边互不相交,这样 \(\sum_{v\in S}h_v\) 就只有 \(\mathcal{O}(n)\) 级别了。原评论链接

既然 \(\sum\limits_{v\in S}h_v\)\(\mathcal{O}(n)\) 级别的,那本质不同的 \(h_v(v\in S)\) 就仅有 \(\mathcal{O}(\sqrt{n})\) 级别了。既然这么少,我们考虑对于每种 \(h_v\) 分别求解。

考虑对于每种可能的高度(不一定 \(v\in S\)),计算它能到达的最低高度,满足这个高度有结点 \(\in S\),定义这种高度为关键高度。设 \(dis_{u,v}\) 表示从 \(u\) 结点到第 \(v\) 个关键高度,需要的最小体力。考虑先往下面的高度转移,然后再同层摊开转移。

对于往下面的高度转移,考虑枚举 \(h_v=h_u-1\),则有:

\[dis_{u,k}=\min(\max(0,dis_{v,k}-1),dis_{u,k}) \]

\(0\)\(\max\) 是为了不出现预支体力的情况,即出现负数。然后再在同层转移,具体来讲,两遍 \(\rm dfs\),一遍从儿子转移到父亲,一遍从父亲转移到儿子即可。(这是 std 的实现,我没想到咋实现)转移时:

\[h_{u,k}=\min(h_{v,k},h_{u,k}+1) \]

最后统计答案即可。时间复杂度 \(\mathcal{O}(n\sqrt{n})\)

#include <queue>
#include <cstdio>
#include <vector>
const int N = 2e5 + 10; int l[N], h[N], vis[N], flip[N], n; std::vector<int> T[N];
struct edge{ int x, y; }E[N]; int book[N], f[N], to[N], tp;
std::vector<int> pnt[N], dis[N];
void getHeight()
{
    std::queue<int> q; 
    for (int i = 1; i <= n; ++i) if (l[i]) q.push(i), vis[i] = 1;
    while (!q.empty())
    {
        int u = q.front(); q.pop();
        for (auto v : T[u])
        {
            if (vis[v]) continue;
            h[v] = h[u] + 1; vis[v] = 1; q.push(v);
        }
    }
}
void up(int u, int fa)
{
    book[u] = 1;
    for (auto v : T[u])
    {
        if (v == fa || h[v] != h[u]) continue;
        up(v, u);
        for (int i = 1; i <= tp; ++i)
            dis[u][i] = std::min(dis[u][i], dis[v][i] + 1);
        dis[u][to[h[u]]] = 1;
    }
}
void down(int u, int fa)
{
    for (auto v : T[u])
    {
        if (v == fa || h[v] != h[u]) continue;
        for (int i = 1; i <= tp; ++i)
            dis[v][i] = std::min(dis[v][i], dis[u][i] + 1);
        dis[v][to[h[v]]] = 1; down(v, u);
    }
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) scanf("%d", l + i);
    for (int i = 1, x, y; i < n; ++i)
        scanf("%d%d", &x, &y), T[x].push_back(y), T[y].push_back(x),
        E[i].x = x, E[i].y = y;
    getHeight(); for (int i = 1; i < n; ++i) if (h[E[i].x] == h[E[i].y]) flip[h[E[i].x]] = 1;
    for (int i = 1; i <= n; ++i) pnt[h[i]].push_back(i);
    for (int i = 0; i <= n; ++i) if (flip[i]) f[++tp] = i, to[i] = tp;
    for (int i = 1; i <= n; ++i) dis[i].resize(tp + 10);
    for (int i = 0; i <= n; ++i)
    {
        for (auto u : pnt[i])
        {
            for (int j = 1; j <= tp; ++j) dis[u][j] = 1e9;
            for (auto v : T[u])
                if (h[v] == h[u] - 1) 
                    for (int j = 1; j <= tp; ++j) 
                        dis[u][j] = std::min(dis[u][j], std::max(0, dis[v][j] - 1));
        }
        for (auto u : pnt[i]) if (!book[u]) up(u, 0), down(u, 0);
    }
    for (int i = 1; i <= n; ++i)
    {
        int ret = h[i];
        for (int j = 1; j <= tp; ++j) 
            if (!dis[i][j]) ret = std::max(ret, 2 * h[i] - f[j]);
        printf("%d ", ret);
    }
    puts(""); return (0 - 0); 
}

H. Three Minimums

对于一个排列 \(p\),我们定义它是好的,当且仅当它的所有长度大于 \(2\) 的子区间 \([l,r]\) 均满足以下条件:

  • 如果 \(p_l,p_r\) 分别是最/次小值(或者分别是次/最小值),则 \((l,r)\) 内的最小值位于 \(p_{l+r}\)\(p_{r-1}\)

给出一个长为 \(m\) 的字符串 \(s\)\(s_i\in\{\mathtt{<,>}\}\) 表示对 \(p_i,p_{i+1}\) 之间大小关系的要求。求有多少 \(1\sim n\) 的排列满足是好的,而且符合 \(s\) 的要求,答案对 \(998,244,353\) 取模。(\(2\le n\le 2\times 10^5,1\le m\le \min(100,n-1)\))

这题就是硬往 ODE 形式凑的吧,条件看着好怪。

遇到新定义,我们先来考虑把原题的条件进行一个转的化。首先忽略掉恶心的 \(s\),然后对于一个排列,本题中最特殊的就是最小值了,所以考虑按照最小值讨论一下。

  • 最小值在排列的中间,形式化讲,\(1<i<n\)\(i\) 为最小值所在位置。此时排列大概长这样:

    \[1\qquad i\qquad n \]

    那我们发现,这个排列是好的的充要条件是,\([1,i],[i,n]\) 都是好的。
    • 充分性,既然 \([1,i],[i,n]\) 都是好的,那我们需要说明的是,跨过 \(i\) 的都是好的。显然跨过 \(i\) 的区间,左右端点都不会成为最小值。
    • 必要性,比较显然,毕竟 \([1,i],[i,n]\subset [1,n]\)
  • 最小值在排列的端点,形式化讲,\(i=1\)\(i=n\),下面讨论不妨令 \(i=1\)。考虑次小值的位置 \(j\),如果它在排列的中间,即 \(1<j<n\),则排列长这样:

    \[i\qquad j\qquad n \]

    此时,这个排列是好的的充要条件是,\([1,j],[j,n]\) 都是好的。证明类似 \(i\) 在中间的情况。
  • 最/次小值分列排列两端,形式化讲,\(i=1,j=n\)\(i=n,j=1\),下面讨论不妨令 \(i=1,j=n\)。排列长这样:

    \[i\qquad k\qquad j \]

    其中 \(p_k=n\)。此时,这个排列好的的充要条件是,\([i,k]\) 上排列递增,\([k,j]\) 上排列递减。
    • 充分性,考虑跨过 \(k\) 和不跨过 \(k\) 的区间。只有前者的端点满足条件,而由于性质,一定满足题意。
    • 必要性,考虑如果不是这样,即出现两个峰值 \(k_1,k_2\)

      \[i\qquad k_1\qquad k_2\qquad j \]

      则考虑区间 \([k_1-q,j]\),满足条件,但次小值会出现在 \((k_1,k_2)\) 内,不满足题意。(\(q\) 是某一常数)

这样,自然地,原排列被分成了三类,细分下来是 \(5\) 类,我们按照原题解的方式给它们定义一些符号。

  • \(a_{\ast\ast}(l,r)\),没有限制,表示区间 \([l,r]\) 对应的好的排列数量,下面的类似。
  • \(a_{1\ast}(l,r)\),最小值在开头。
  • \(a_{\ast1}(l,r)\),最小值在结尾。
  • \(a_{12}(l,r)\),最小值在开头,次小值在结尾。
  • \(a_{21}(l,r)\),次小值在开头,最小值在结尾。

然后就可以通过上面的转化条件来推一推它们之间的关系了。不过在此之前,是时候捡起来刚刚扔掉的 \(s\),处理一下限制了。考虑设 \(q(x,c)=[x>m\operatorname{or}s_x=c],c\in\{\mathtt{<,>}\}\),把它们加入式子就能处理 \(s\) 了。

  • \(a_{\ast\ast}(l,r)=\sum\limits_{i=l}^r a_{\ast 1}(1,i)a_{1\ast}(i,r)\binom{r-l}{i-l}\),后面的组合数是,枚举 \(r-l\) 的值,哪些分给前面的排列。组合意义是,第一种情况下,枚举 \(i\) 的位置。
  • \(a_{1\ast}(l,r)=\begin{cases}\sum\limits_{i=l+1}^r a_{12}(l,i)a_{1\ast}(i,r)\binom{r-l-1}{i-l-1}&l<r\\1&l=r\end{cases}\),组合意义是,第二种情况下,枚举 \(j\) 的位置。
  • \(a_{\ast1}(l,r)=\begin{cases}\sum\limits_{i=l}^{r-1} a_{\ast1}(l,i)a_{21}(i,r)\binom{r-l-1}{i-l}&l<r\\1&l=r\end{cases}\),类似。
  • \(a_{12}(l,r)=\begin{cases}q(l,\mathtt{<})a_{21}(l+1,r)+q(r-1,\mathtt{>})a_{12}(l,r-1)&l+2<r\\ q(l,\mathtt{<})q(l+1,\mathtt{>})&l+2=r\\ q(l,\mathtt{<})&l+1=r\\0&l=r\end{cases}\),组合意义是,注意到第三种情况下,我们枚举 \((l,r)\) 的最小值在哪,然后就是个子问题了。
  • \(a_{12}(l,r)=\begin{cases}q(l,\mathtt{<})a_{21}(l+1,r)+q(r-1,\mathtt{>})a_{12}(l,r-1)&l+2<r\\ q(l,\mathtt{<})q(l+1,\mathtt{>})&l+2=r\\ q(l,\mathtt{>})&l+1=r\\0&l=r\end{cases}\),类似。

有了这些式子,我们现在已经能在 \(\mathcal{O}(n^2)\) 的时间复杂下计算了,期望得分 \(\tt 0pts\)。(?)考虑优化,发现最终的答案形式是:

\[a_{\ast\ast}(1,n)=\sum_{i=1}^{n}a_{\ast 1}(1,i)a_{1\ast}(i,n)\binom{n-1}{i-1} \]

所以,现在的主要问题是求出 \(a_{\ast 1}(1,k),a_{1,\ast}(k,n)(1\le k\le n)\)

我们先来讨论一下 \(a_{\ast 1}(1,k)\)。套入刚刚的计算式,我们得到:

\[a_{\ast 1}(1,k)=\sum_{i=1}^{k-1}a_{\ast 1}(1,i)a_{21}(i,k)\binom{k-2}{i-1}(k\ge 2) \]

注意到,我们有 \(a_{21}(k,k)=0\),所以完全可以变为:

\[a_{\ast 1}(1,k)=\sum_{i=1}^{k}a_{\ast 1}(1,i)a_{21}(i,k)\binom{k-2}{i-1}(k\ge 2) \]

然后,考虑把组合数拆成阶乘:

\[\dfrac{a_{\ast 1}(1,k)}{(k-2)!}=\sum_{i=1}^k\dfrac{a_{\ast 1}(1,i)}{(i-1)!}\dfrac{a_{21}(i,k)}{(k-1-i)!}(k\ge 2) \]

注意到左右有一个地方形式很统一,考虑换个元:

\[x_i=\dfrac{a_{\ast1}(1,i+1)}{i!} \]

然后令 \(k\) 表示 \(k-1\)\(i\) 表示 \(i-1\),则有(检查一下 \(k=0,1\) 的情况,发现也满足,所以可以把 \(k\ge 2\) 的条件扔掉:):

\[kx_k=\sum_{i=0}^{k-1}x_i\dfrac{a_{21}(i,k)}{(k-1-i)!} \]

推到这里,我们大概能发现卷积的形式了,但非常遗憾,\(a_{21}(i,k)\) 这个东西,如果想要预处理只能做到 \(\mathcal{O}(n^2)\),依然没有什么用。

这里有一个 key observation,注意到一旦 \(l>m\),则一定有 \(a(l,r)=a(l+1,r+1)\),所以许多大于 \(m\) 的状态是相同的。考虑先全部按照 \(l>m\) 的简单情况计算,再单独处理 \(\le m\) 的情况,毕竟后者非常好算。考虑设:

\[b(k)=a(n+1,n+k) \]

来表示所有 \(l>m\) 的情况,则原式子变为:

\[kx_k=\sum_{i=0}^{k-1}x_i\dfrac{b_{21}(k-i+1)}{(k-1-i)!}+\sum_{i=0}^{\min(k-1,m-1)}x_i\dfrac{a_{21}(i+1,k+1)-b_{21}(k-i+1)}{(k-1-i)!} \]

看起来有点复杂,但其实基本思想跟刚刚说的一样,前面的是通常情况,后面的是 \(\le m\) 的调整。而又注意到不带限制的情况 \(b_{21}\) 是特别好递推的:

\[b_{21}(k)=2b_{21}(k-1) \]

接下来我们剩下要做的事就是整理整理,转化成我们早就想要的 GF 形式了。考虑设:

\[u_k=\dfrac{b_{21}(k+2)}{k!}, v_k=\sum_{i=0}^{\min(k,m-1)}\dfrac{a_{21}(i+1,k+2)-b_{21}(k-i+2)}{(k-i)!}\]

则原式变为:

\[kx_k=v_{k-1}+\sum_{i=0}^{k-1}x_iu_{k-1-i} \]

完美符合 GF 的形式!不过还是走一下形式吧,设三个 GF:

\[F(z)=\sum_{n\ge 0}x_nz^n,G(z)=\sum_{n\ge 0}v_nz^n,H(z)=\sum_{n\ge 0}u_nz^n \]

则:

\[F'(z)=G(z)+F(z)H(z) \]

非常标准的常微分方程(ODE)形式。有两种办法,一种是数学方法爆解,具体解可以去看原题解。而对于我这种高数很差的人来说,我选择半在线卷积。把 \(k\) 当常数除过去,就是标准的半在线卷积形式了,记得做贡献要位移,\(\frac{1}{k}\) 在分治边界乘上即可。

总结一下为了得到 \(a_{\ast 1}(1,k)\) 我们都需要干点什么。首先需要在 \(\mathcal{O}(nm)\) 的时间复杂度下预处理出 \(a_{21}(l,r)(l\in[1,m],r\in[1,n])\),在 \(\mathcal{O}(n)\) 下预处理出 \(b_{21}\),然后在 \(\mathcal{O}(m^2)\) 下预处理出 \(x_i(i\in[1,m])\)。最后做一次半在线卷积,总时间复杂度 \(\mathcal{O}((n+m)m+n\log^2n)\)

然后,我们转回头来看看差不多的 \(a_{1\ast}(k,n)\) 怎么办。说是差不多,其实还有点不一样,因为 \(a_{1\ast}(k,n)\)\([1,m]\) 看起来更远了一点,所以对于 \(a_{\ast1}(1,k)\) 推导中的 \(G(z)\) 必然要有不一样的处理。

类似的思路,我们有:

\[a_{1\ast}(k,n)=\sum_{i=k+1}^n a_{12}(k,i)a_{1\ast}(i,n)\binom{n-k-1}{i-k-1} \]

然后拆组合数:

\[\dfrac{a_{1\ast}(k,n)}{(n-k-1)!}=\sum_{i=k+1}^n\dfrac{a_{12}(k,i)}{(i-k-1)!}\dfrac{a_{1\ast}(i,n)}{(n-i)!} \]

换元,不过这里为了凑出 ODE 的形式设的有点不太一样:

\[x_k=\dfrac{a_{1\ast}(n-k,n)}{(k-1)!} \]

然后就会变成:

\[(n-k)x_{n-k}=\sum_{i=k+1}^{n}x_{n-i}\dfrac{a_{12}(k,i)}{(i-k-1)!} \]

不好看,换个形式。令 \(k\) 代表 \(n-k\)\(i\) 代表 \(n-i\)

\[kx_k=\sum_{i=0}^{k-1}x_i\dfrac{a_{12}(n-k,n-i)}{(k-1-i)!} \]

又遇到了类似的问题,看起来很卷积,但是 \(a_{12}\) 不好求。这一次,如果还用类似的方法解决会怎么样呢:

\[kx_k=\sum_{i=0}^{k-1}x_i\dfrac{b_{12}(k-i+1)}{(k-1-i)!}+[k\ge n-m]\sum_{i=0}^{k-1}\dfrac{a_{12}(n-k,n-i)-b_{12}(k-i+1)}{(k-1-i)!} \]

这个式子,蛮怪的。因为我们发现,是否需要调整和 \(i\) 无关。而需要调整的地方又只有 \(m\) 处。所以我们可以考虑提前求出来这 \(m\) 个需要调整的地方,而剩下的就不需要调整了!具体来讲,我们设:

\[y_k=\dfrac{b_{1\ast}(k)}{(k-1)!} \]

然后有:

\[y_k=\sum_{i=0}^{k-1}y_i\dfrac{b_{12}(k-i+1)}{(k-1-i)!} \]

常数项非常理所应当地消失了。我们只需要设:

\[u_i=\dfrac{b_{12}(i+2)}{i!},F(z)=\sum_{n\ge 0}y_nz^n,H(z)=\sum_{n\ge 0}u_nz^n \]

就能得到非常好看的 ODE 形式:

\[F'(z)=H(z)F(z) \]

依然可以半在线卷积求解。

总结一下为了得到 \(a_{1\ast}(k,n)\) 我们需要求什么。首先 \(\mathcal{O}(nm)\) 求出 \(a_{1\ast}(k,n)(1\le k\le m)\),然后 \(\mathcal{O}(n)\) 求出 \(b_{12}(k)\)(递推 \(b_{12}\) 的公式和 \(b_{21}\) 的是一样的),最后做一次半在线卷积。得到 \(b_{1\ast}\) 后,\(a_{12}(k,n)=b_{1\ast}(n-k+1)(m<k\le n)\)。总时间复杂度 \(\mathcal{O}(nm+n\log ^2n)\)

最后,解完这两个 ODE,统计答案即可。最终时间复杂度 \(\mathcal{O}((n+m)m+n\log^2n)\)

写在代码前的友情提醒:我的代码有大常数!

#pragma GCC optimize("Ofast")
#include <cstdio>
#include <cstring>
#include <cassert>
#include <algorithm>
const int N = 1e6 + 10, mod = 998244353; typedef long long ll; 
int F[N], G[N], H[N], A[N], B[N], m, lim, rev[N];
inline int ksm(int a, int b)
{
    int ret = 1;
    while (b)
    {
        if (b & 1) ret = (ll)ret * a % mod;
        a = (ll)a * a % mod; b >>= 1;
    }
    return ret;
}
inline void init(int n)
{
    m = 0; lim = 1; while (lim <= n) lim <<= 1, ++m;
    for (int i = 0; i < lim; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (m - 1));
}
inline void NTT(int* f, int len, int on)
{
    for (int i = 0; i < len; ++i) if (i < rev[i]) std::swap(f[i], f[rev[i]]);
    for (int h = 2; h <= len; h <<= 1)
    {
        int gn = ksm(3, (ll)((mod - 1) / h) * on % (mod - 1));
        for (int j = 0; j < len; j += h)
            for (int k = j, g = 1; k < j + h / 2; ++k, g = (ll)g * gn % mod)
            {
                int u = f[k], t = (ll)g * f[k + h / 2] % mod;
                f[k] = (u + t) % mod; f[k + h / 2] = ((u - t) % mod + mod) % mod;
            }
    }
    if (on == mod - 2) for (int i = 0, inv = ksm(len, mod - 2); i < len; ++i) f[i] = (ll)f[i] * inv % mod;
}
char s[N]; int b12[N], b21[N], ax1[110][110], a1x[110][110], a12[110][N], a21[110][N];
int fac[N], ifac[N], Ax1[N], A1x[N], B1x[N];
inline int C(int n, int m) { return (ll)fac[n] * ifac[m] % mod * ifac[n - m] % mod; }
inline void initF(int n)
{
    fac[0] = ifac[0] = 1;
    for (int i = 1; i <= n; ++i) fac[i] = (ll)fac[i - 1] * i % mod;
    ifac[n] = ksm(fac[n], mod - 2);
    for (int i = n - 1; i; --i) ifac[i] = (ll)ifac[i + 1] * (i + 1) % mod;
}
inline void mul(int* des, int* src, int n, int m)
{
    for (int i = 0; i < n; ++i) A[i] = des[i];
    for (int i = 0; i < m; ++i) B[i] = src[i];
    for (int i = n; i < lim; ++i) A[i] = 0;
    for (int i = m; i < lim; ++i) B[i] = 0;
    NTT(A, lim, 1); NTT(B, lim, 1);
    for (int i = 0; i < lim; ++i) A[i] = (ll)A[i] * B[i] % mod;
    NTT(A, lim, mod - 2);
}
inline void cdq(int l, int r)
{
    if (l + 1 == r)
        return !l ? void() : ((F[l] += G[l - 1]) %= mod, F[l] = (ll)F[l] * ksm(l, mod - 2) % mod, void());     
    int mid = (l + r) >> 1; cdq(l, mid); init(r - l);
    mul(F + l, H, mid - l, r - l - 1);
    for (int i = mid; i < r; ++i) (F[i] += A[i - l - 1]) %= mod;
    cdq(mid, r);
}
inline int Q(int x, char c) { return !s[x] || s[x] == c; }
int main()
{
    int n, m; scanf("%d%d%s", &n, &m, s + 1); initF(n);
    b12[2] = b21[2] = b12[3] = b21[3] = 1;
    for (int i = 4; i <= n + 2; ++i)
        b12[i] = 2ll * b12[i - 1] % mod, b21[i] = 2ll * b21[i - 1] % mod;
    for (int k = 1; k <= n; ++k)
        for (int l = 1; l <= m && l + k - 1 <= n + 1; ++l)
        {
            int r = l + k - 1;
            if (l == r) a12[l][r] = a21[l][r] = 0; 
            else if (l + 1 == r) a12[l][r] = Q(l, '<'), a21[l][r] = Q(l, '>');
            else if (l + 2 == r)
                a12[l][r] = Q(l, '<') * Q(l + 1, '>'),
                a21[l][r] = Q(l, '<') * Q(l + 1, '>');
            else
            {
                if (l < m)
                {
                    a12[l][r] = (Q(l, '<') * a21[l + 1][r] + Q(r - 1, '>') * a12[l][r - 1]) % mod;
                    a21[l][r] = (Q(l, '<') * a21[l + 1][r] + Q(r - 1, '>') * a12[l][r - 1]) % mod;
                }
                else
                {
                    a12[l][r] = (Q(l, '<') * b21[r - l] + Q(r - 1, '>') * a12[l][r - 1]) % mod;
                    a21[l][r] = (Q(l, '<') * b21[r - l] + Q(r - 1, '>') * a12[l][r - 1]) % mod;
                }
            }
        }  
    for (int k = 1; k <= m; ++k)
    {
        int r = k;
        if (1 == r) { ax1[1][r] = 1; continue; }
        for (int i = 1; i < r; ++i)
            (ax1[1][r] += (ll)ax1[1][i] * a21[i][r] % mod * C(r - 2, i - 1) % mod) %= mod;
    } 
    F[0] = 1;
    for (int k = 0; k < n; ++k) H[k] = (ll)ifac[k] * b21[k + 2] % mod;
    for (int k = 1; k <= n; ++k)
    {
        int add = 0;
        for (int i = 0; i < std::min(k, m); ++i)
            (add += (ll)ifac[k - i - 1] * (a21[i + 1][k + 1] - b21[k + 1 - i]) % mod
            * ax1[1][i + 1] % mod * ifac[i] % mod) %= mod;
        (add += mod) %= mod; G[k - 1] = add;
    }
    int k = 1; while (k <= n) k <<= 1; 
    cdq(0, k);
    for (int i = 0; i < n; ++i) Ax1[i + 1] = (ll)fac[i] * F[i] % mod;
    memset(F, 0, sizeof (F)); memset(G, 0, sizeof (G)); memset(H, 0, sizeof (H));
    F[0] = 1;
    for (int k = 0; k <= n; ++k) H[k] = (ll)ifac[k] * b12[k + 2] % mod;
    cdq(0, k); 
    for (int i = 0; i < n; ++i) B1x[i + 1] = (ll)fac[i] * F[i] % mod;
    for (int i = m + 1; i <= n; ++i) A1x[i] = B1x[n - i + 1];
    for (int k = n - m + 1; k <= n; ++k)
    {
        int l = n - k + 1;
        for (int i = l + 1; i <= n; ++i)
            (A1x[l] += (ll)a12[l][i] * A1x[i] % mod * C(n - l - 1, i - l - 1) % mod) %= mod;
    } 
    int ans = 0;
    for (int i = 1; i <= n; ++i) (ans += (ll)Ax1[i] * A1x[i] % mod * C(n - 1, i - 1) % mod) %= mod;
    printf("%d\n", ans); return 0;
}
posted @ 2022-03-24 16:13  zhiyangfan  阅读(109)  评论(1编辑  收藏  举报