「解题报告」UOJ577 【ULR #1】打击复读
题目大意
给你一个长度为 \(n\) 字符串 \(s\) 与两个长度为 \(n\) 的权值数组,左权值 \(wl\) 和 右权值 \(wr\),定义一个子串的权值为该子串在 \(s\) 中的所有出现位置的左端点 \(wl\) 和与所有出现位置右端点 \(wr\) 的和的乘积。
\(m\) 次修改,每次修改单个 \(wl\) 值,每次询问求 \(s\) 中所有子串的权值之和。
\(n, m\le 5\times 10^5\)
思路
看到出现位置,我们可以考虑这东西其实就是 \(endpos\) 集合。
那么我们先考虑对原串建立一个后缀自动机,那么我们可以把答案转换成求后缀自动机上子串的权值之和。
对于同一个节点,它的 \(endpos\) 集合是不变的,也就是这同一个节点的 \(wr\) 之和是不变的,那么我们只需要求出该节点代表的所有子串的左端点权值之和就可以统计答案了。
也就是说,假设一个节点 \(u\) 代表子串的长度区间为 \([L_u,R_u]\),那么这个子串对答案的贡献就是:
那么我们现在就是要求 \(\sum_{i\in endpos_u} Swl_{i-R_u+1}\) 了。
我们发现,这其实就是子串 \(s[i - R_u + 1, i]\) 的出现位置的左端点取值。那么我们可以用类似的方法,建出其反串的 SAM,并且在反串的 SAM 上找到相对应的节点就做完了。
具体找节点方法可以使用倍增做到 \(O(n\log n)\),然后貌似还可以离线分块并查集做到 \(O(n)\)。
对于修改,我们可以将 \(wl\) 和 \(wr\) 调换(或者最一开始就建反串的 SAM),这样修改就只会修改 \(\left(\sum_{i\in endpos_u} wr_i\right)\),那么就可以通过数上链和直接统计答案的变化量。
这样子就做完了,不过我们还可以挖掘更厉害的性质。
上面使用了差分的思想,但是我们可以从另一个角度去考虑这个式子:
其中 \(f(t)\) 指 \(t\) 在 \(s\) 中出现所有位置的 \(wl\) 和。
我们还是建出正反 SAM,然后我们直接从自动机的角度去考虑这件事情。
对于在正串的 \(parent\) 树上的某一个节点,我们往下走一个节点,实际上是往前添加了几个字符。
那么对于反串上,实际就是走了几条转移边。
我们可以考虑同时在两个 SAM 上进行 DFS 来维护出这样的权值,但是这样每一次转移一步在反串 SAM 上会跳 \(R_u-L_u+1\) 次转移边,复杂度肯定是不能接受的。
有这样一个性质:对于现在正串上一个节点所代表的非最长子串,它对应的反串节点一定只有一个转移边;而如果是一个非原串后缀的节点的最长子串,它对应的反串节点一定至少有两条转移边。
因为如果它的反串节点有多于一条转移边,那么相当于我们可以在这个非最长子串前添加不同的字符,那么这个子串肯定不能是非最长子串,而是一个最长子串。需要注意的是,如果这个节点的最长串是原串的一个后缀,那么它可能也只有一条转移边,因为这个子串的最后一个出现位置是不能继续往后拓展的。
那么我们就发现,当我们在正串 \(parent\) 树上 DFS 的时候,反串一定是跳过了若干出度为 \(1\) 的非后缀节点,并且贡献的权值等于这些出度为 \(1\) 的点的权值和。那么我们就可以对反串后缀自动机进行处理,\(O(n)\) 预处理出每个点向前一直跳出度为 \(1\) 的非后缀节点到达的节点与经过的节点上的权值和。于是,我们就可以将前面的跳 \(R_u-L_u+1\) 次转移边变成跳 \(1\) 次转移边了。这样复杂度就变成了 \(O(n)\)。
于是我们就得到了一个简单的 \(O(n+m)\) 复杂度的算法。
这个算法还具有很大的扩展性,因为不使用差分,所以还可以做到维护 \(\min,\max\) 等权值,甚至可以处理更复杂的贡献函数。
这东西据说是压缩后缀自动机,但是我很菜看不懂也懒得看论文也找不到论文于是就咕了)
Code
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1000005;
typedef unsigned long long ull;
int n;
char s[MAXN];
struct SuffixAutomaton {
int t[MAXN][4], fail[MAXN], len[MAXN], pos[MAXN];
int lst[MAXN];
int tot;
SuffixAutomaton() : tot(1) { lst[0] = 1; }
void insert(int c, int i) {
int p = lst[i - 1], np = lst[i] = ++tot;
len[np] = len[p] + 1, pos[np] = i;
while (p && !t[p][c]) t[p][c] = np, p = fail[p];
if (!p) {
fail[np] = 1;
} else {
int q = t[p][c];
if (len[q] == len[p] + 1) {
fail[np] = q;
} else {
int nq = ++tot; memcpy(t[nq], t[q], sizeof t[q]);
len[nq] = len[p] + 1, pos[nq] = pos[q];
fail[nq] = fail[q], fail[q] = fail[np] = nq;
while (p && t[p][c] == q) t[p][c] = nq, p = fail[p];
}
}
}
} sam1, sam2;
vector<int> g1[MAXN], g2[MAXN];
void buildGraph(SuffixAutomaton &sam, vector<int> g[]) {
for (int i = 2; i <= sam.tot; i++) {
g[sam.fail[i]].push_back(i);
}
}
ull wl[MAXN], wr[MAXN], swr[MAXN], f[MAXN], F[MAXN];
int tim[MAXN];
void dfs1(int u) {
for (int v : g1[u]) {
dfs1(v);
tim[u] += tim[v];
swr[u] += swr[v];
}
}
int nxt[MAXN];
int cnt[MAXN], id[MAXN];
bool flag[MAXN];
void compress() {
for (int i = 1; i <= sam1.tot; i++) cnt[sam1.len[i]]++;
for (int i = 1; i <= n; i++) cnt[i] += cnt[i - 1];
for (int i = sam1.tot; i >= 1; i--) id[cnt[sam1.len[i]]--] = i;
int p = sam1.lst[n];
while (p) flag[p] = 1, p = sam1.fail[p];
for (int i = sam1.tot; i >= 1; i--) {
int u = id[i];
swr[u] *= tim[u];
int deg = 0, v;
for (int i = 0; i < 4; i++) if (sam1.t[u][i]) {
deg++;
v = sam1.t[u][i];
}
if (!flag[u] && deg == 1) {
nxt[u] = nxt[v];
swr[u] += swr[v];
} else nxt[u] = u;
}
}
void dfs2(int u, int p) {
for (int v : g2[u]) {
int i = s[n - sam2.pos[v] + 1 + sam2.len[u]];
f[v] = f[u] + swr[sam1.t[p][i]];
dfs2(v, nxt[sam1.t[p][i]]);
}
}
int m;
int main() {
scanf("%d%d%s", &n, &m, s + 1);
for (int i = 1; i <= n; i++) {
if (s[i] == 'A') s[i] = 0;
if (s[i] == 'C') s[i] = 1;
if (s[i] == 'T') s[i] = 2;
if (s[i] == 'G') s[i] = 3;
}
for (int i = 1; i <= n; i++) scanf("%llu", &wl[i]);
for (int i = 1; i <= n; i++) scanf("%llu", &wr[i]);
for (int i = 1; i <= n; i++) sam1.insert(s[i], i);
for (int i = n; i >= 1; i--) sam2.insert(s[i], n - i + 1);
for (int i = 1; i <= n; i++) tim[sam1.lst[i]] = 1, swr[sam1.lst[i]] = wr[i];
buildGraph(sam1, g1), buildGraph(sam2, g2);
dfs1(1), compress(), dfs2(1, 1);
for (int i = 1; i <= n; i++) F[i] = f[sam2.lst[n - i + 1]];
ull ans = 0;
for (int i = 1; i <= n; i++) ans += wl[i] * F[i];
printf("%llu\n", ans);
for (int i = 1; i <= m; i++) {
int p; ull w; scanf("%d%llu", &p, &w);
ans += (w - wl[p]) * F[p];
wl[p] = w;
printf("%llu\n", ans);
}
return 0;
}