[HNOI2019]JOJO

[HNOI2019]JOJO

[题目链接]

链接

[思路要点]

题目询问的是当前字符串做 \(\text{kmp}\) 之后的 \(\text{nxt}\) 数组的值的和

首先考虑没有第二种操作的情况

将添加操作看成添加一个字符,这个字符有两个属性,长度和字符。

不难发现,两个子串相匹配,每个子串拆分成开头某一段的后缀 + 中间一堆完整的段 + 结尾某一段的前缀,而题目中一定是某个前缀和某个后缀相匹配,也就是第一个串的开头是完整的段,最后一个串的结尾是完整的段,最后一个串的开头所在段比第一个串可能长,最后一个串结尾所在段比第一个串可能短

由于中间部分完全匹配,所以直接做普通的 \(\text{kmp}\) 即可,开头由于没有修改,直接令全文的第一段能够匹配所有字符和它一样但长度大于等于它的段就处理好了,关键在于结尾串的处理

考虑 \(\text{kmp}\) 的实现过程,实质上是不断跳 \(nxt\) 的过程,由于题目保证每次加入的字符不和之前最后一个字符相同,那么 \(nxt\) 跳在某一段中间时是不可能匹配的,可以减少一些情况。不妨设最后一段的字符是 c,现在跳到 \(nxt\) 有三处紧跟的字符都是 c,第一处为 (p1,c,l1),第二处为 (p2,c,l2),第三处为 (p3,c,l3),其中 p 表示 \(nxt\) 的值,也就是跳到的位置,l 表示这一段 c 的长度。假设 p1<p2<p3l1>l2>l3

那么显然,设当前新添加的 c 的长度为 l,下标为 [1,l],那么其 [1,l3] 一段应该和第三处匹配,[l3+1,l2] 一段和第二处匹配,即每次找到一段 c 都覆盖一段位置,并且不能覆盖前面覆盖过的位置(不优),每次的贡献都是一段等差数列。

这样复杂度就是 \(\Theta (m)\) 的。考虑加入第二种操作。首先可以建出操作树并 dfs 完成撤销操作,但是由于 \(\text{kmp}\) 复杂度是均摊 \(\Theta(1)\) 的,不能保证每次操作都很快,那么一个不停返回上一个状态并做一次较慢的操作就能卡掉该算法。实质上,由于总字符数范围大概在 \(1e9\) 左右,暴力 kmp 已经在通过的边缘了,可能剪剪枝就能通过了。

考虑一个叫 \(\text{kmp}\) 自动机的东西,它的本质是把 \(\text{kmp}\)\(\text{nxt}\) 的过程预处理,由于本题字符集大小很大,用主席树维护。

\(f_{i,j,k}\) 表示在串的 \(s_{i-1}\) 位置添加一个字符 \((j,k)\)\(nxt\) 所到达的位置,同理设 \(g_{i,j,k}\) 表示增加的答案。在dfs的时候,修改 \(f_{i,x,c}\) 的值,并将 \(g_{i,x,1\dots c}\)设置为首项为当前串长度,公差为 \(1\) 的等差数列。dfs下一层前把 \(f_{i+1}\) 的状态由 \(f_{next[i]+1}\) 继承过来。由于每次加入的是一个等差数列,可以预先减掉下标值,这样变成了区间赋值,统计时再加上每个下标的值即可。

[代码]

// Copyright: lzt
#include<stdio.h>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<cmath>
#include<iostream>
#include<queue>
#include<string>
#include<ctime>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
typedef long double ld;
typedef unsigned long long ull;
typedef pair<long long,long long> pll;
typedef pair<int, pair<int, long long> > lzt;
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define rep(i,j,k)  for(register int i=(int)(j);i<=(int)(k);i++)
#define rrep(i,j,k) for(register int i=(int)(j);i>=(int)(k);i--)
#define Debug(...) fprintf(stderr, __VA_ARGS__)

ll read(){
    ll x=0,f=1;char c=getchar();
    while(c<'0' || c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0' && c<='9'){x=x*10+c-'0';c=getchar();}
    return x*f;
}

inline char gc() {
    char c = getchar();
    while (c < 'a' || c > 'z') c = getchar();
    return c;
}

const int maxn = 100005, M = 1e4 + 7, mod = 998244353;

int n;
int val[maxn], pos[maxn], ans[maxn], a[maxn], b[maxn], top;
vector<int> to[maxn];

int rt[maxn][26], mx[maxn][26], tot;
struct seg
{
    int l, r, lch, rch, sum, lzy, nxt;
} t[maxn * 60];

#define mid ((l + r) >> 1)

inline void new_node(int &s) {t[++tot] = t[s]; s = tot;}
inline void add(int s, int v, int len) {t[s].sum = (ll)v * len % mod; t[s].lzy = v;}
inline void push_down(int s, int l, int r)
{
    if (!t[s].lzy) return ;
    new_node(t[s].lch); add(t[s].lch, t[s].lzy, mid - l + 1);
    new_node(t[s].rch); add(t[s].rch, t[s].lzy, r - mid);
    t[s].lzy = 0;
}

void modify(int &s, int l, int r, int x, int v, int p)
{
    new_node(s);
    if (r < x) return add(s, v, r - l + 1);
    if (l == r) return t[s].nxt = p, add(s, v, 1);
    push_down(s, l, r);
    modify(t[s].lch, l, mid, x, v, p);
    if (x > mid) modify(t[s].rch, mid + 1, r, x, v, p);
    t[s].sum = (t[t[s].lch].sum + t[t[s].rch].sum) % mod;
}

void query(int &s, int l, int r, int x, int &ans, int &nxt)
{
    if (r < x) return ans = (ans + t[s].sum) % mod, void();
    if (l == r) return ans = (ans + t[s].sum) % mod, nxt = t[s].nxt, void();
    push_down(s, l, r);
    query(t[s].lch, l, mid, x, ans, nxt);
    if (x > mid) query(t[s].rch, mid + 1, r, x, ans, nxt);
}

inline int getsum(int x) {return ((ll)x * (x + 1) >> 1) % mod;}

void dfs(int u)
{
    ++top;
    int x = val[u] / M, y = val[u] % M, nxt = 0;
    a[top] = val[u]; b[top] = b[top - 1] + y;
    if (top == 1) ans[u] = getsum(y - 1);
    else {
        ans[u] = (ans[u] + getsum(min(mx[top][x], y))) % mod;
        query(rt[top][x], 1, M, y, ans[u], nxt);
        if (!nxt && a[1] / M == x && b[1] < y) nxt = 1, ans[u] = (ans[u] + (ll)b[1] * max(0, y - mx[top][x])) % mod;
    }
    mx[top][x] = max(mx[top][x], y);
    modify(rt[top][x], 1, M, y, b[top - 1], top);
    for (int i=0;i<to[u].size();i++) {
    	int v = to[u][i];
        memcpy(mx[top + 1], mx[nxt + 1], sizeof(mx[top + 1]));
        memcpy(rt[top + 1], rt[nxt + 1], sizeof(rt[top + 1]));
        ans[v] = ans[u]; dfs(v);
    }
    --top;
}

void work() {
    n = read();
    for (int op, x, i = 1; i <= n; ++i) {
        op = read(); x = read();
        if (op == 1) val[++tot] = (gc() - 'a') * M + x, pos[i] = tot, to[pos[i - 1]].push_back(pos[i]);
        else pos[i] = pos[x];
    }

    for (int _ = 0; _ < to[0].size(); _++) {
    	int i = to[0][_];
        tot = 0;
        memset(rt[1], 0, sizeof(rt[1]));
        memset(mx[1], 0, sizeof(mx[1]));
        dfs(i);
    }

    for (int i = 1; i <= n; ++i) printf("%d\n", ans[pos[i]]);
}

int main(){
    #ifdef LZT
        freopen("in","r",stdin);
    #endif

    work();

    #ifdef LZT
        Debug("My Time: %.3lfms\n", (double)clock() / CLOCKS_PER_SEC);
    #endif
}
posted @ 2019-06-30 08:51  wawawa8  阅读(559)  评论(0编辑  收藏  举报