[BZOJ3473][BZOJ3277]字符串
[BZOJ3473][BZOJ3277]字符串
试题描述
给定 \(n\) 个字符串,询问每个字符串有多少子串(不包括空串)是所有 \(n\) 个字符串中至少 \(k\) 个字符串的子串?
输入
第一行两个整数 \(n\),\(k\)。
接下来 \(n\) 行每行一个字符串。
输出
一行 \(n\) 个整数,第 \(i\) 个整数表示第 \(i\) 个字符串的答案。
输入示例
3 1
abc
a
ab
输出示例
6 1 3
数据规模及约定
对于 \(100\%\) 的数据,\(1 \le n\),\(k \le 10^5\),所有字符串总长不超过 \(10^5\),字符串只包含小写字母。
题解
注意这题的子串指的是位置不同的子串算不同的子串。
这个题我暴力后缀数组 + 主席树 + 并查集 + 离线标记做的,估计比我简单的做法有不少……反正这也算一道开放性的题目了吧,我是顺着思路就把它想下来了,并没有管实现由多麻烦……
首先将所有串拼接在一起,中间用特殊的分隔符隔开(每个分隔符都要独一无二),然后进行后缀排序。
现在我们考虑的问题都是在这个后缀序列上的。
对于某个长度为 \(l\) 的符合题意的子串,一定在一段区间中,这段区间的 \(height\) 值都 \(\ge l\);并且我们知道对于这段区间中每个后缀,它们的长度 \(\le l\) 的前缀都是满足要求的子串,所以对于这段区间中的每个后缀的贡献就是 \(l\) 个子串。
题目还要求这样的区间中所有后缀的起点在至少 \(k\) 个串中,才能累计贡献,令能累计贡献的区间叫做合法区间。
除此之外还需要注意的一点是:每个后缀的区间只能被它所属的区间内最小 \(height\) 值最大的合法区间贡献一次,否则会重复(因为对于 \(l' < l\) 找到的合法区间一定也会包含它,而这时该后缀的长度 \(\le l'\) 的前缀已经在长度为 \(l\) 时统计过了)。
所以接下来算法过程就很明显了,从大到小枚举 \(l\),然后并查集合并区间,当区间内部第一次有超过 \(k\) 个串时给这个区间内每个后缀的计数器上加一个 \(l\),最后对于每个后缀看它起点在哪个串中并把这个后缀上计数器的值加到那个串的答案中去就好了。
查询区间内部有多少个串要用到主席树。
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
#include <cmath>
using namespace std;
#define rep(i, s, t) for(int i = (s), mi = (t); i <= mi; i++)
#define dwn(i, s, t) for(int i = (s), mi = (t); i >= mi; i--)
int read() {
int x = 0, f = 1; char c = getchar();
while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); }
return x * f;
}
#define maxn 200010
#define maxnode 3600180
#define LL long long
char S[maxn];
int n, str[maxn];
int sa[maxn], Ws[maxn], rnk[maxn], height[maxn];
bool equ(int *a, int p1, int p2, int len) {
if(p1 + len > n && p2 + len > n) return a[p1] == a[p2];
if(p1 + len > n || p2 + len > n) return 0;
return a[p1] == a[p2] && a[p1+len] == a[p2+len];
}
void ssort() {
int *x = rnk, *y = height, m = 0;
rep(i, 1, n) Ws[x[i] = str[i]]++, m = max(m, str[i]);
rep(i, 1, m) Ws[i] += Ws[i-1];
dwn(i, n, 1) sa[Ws[x[i]]--] = i;
for(int j = 1, pos; j < n; j <<= 1, m = pos) {
pos = 0;
rep(i, n - j + 1, n) y[++pos] = i;
rep(i, 1, n) if(sa[i] > j) y[++pos] = sa[i] - j;
rep(i, 1, m) Ws[i] = 0;
rep(i, 1, n) Ws[x[i]]++;
rep(i, 1, m) Ws[i] += Ws[i-1];
dwn(i, n, 1) sa[Ws[x[y[i]]]--] = y[i];
swap(x, y); x[sa[1]] = pos = 1;
rep(i, 2, n) x[sa[i]] = equ(y, sa[i], sa[i-1], j) ? pos : ++pos;
}
return ;
}
void calch() {
rep(i, 1, n) rnk[sa[i]] = i;
for(int i = 1, j, k = 0; i <= n; height[rnk[i++]] = k)
for(k ? k-- : 0, j = rnk[i]; str[sa[j]+k] == str[sa[j-1]+k]; k++);
return ;
}
int at[maxn], lstv[maxn], lst[maxn], ToT, rt[maxn], sumv[maxnode], lc[maxnode], rc[maxnode];
void update(int& y, int x, int l, int r, int p) {
sumv[y = ++ToT] = sumv[x] + 1;
if(l == r) return ;
int mid = l + r >> 1; lc[y] = lc[x]; rc[y] = rc[x];
if(p <= mid) update(lc[y], lc[x], l, mid, p);
else update(rc[y], rc[x], mid + 1, r, p);
return ;
}
int query(int o, int l, int r, int qr) {
if(!o) return 0;
if(r <= qr) return sumv[o];
int mid = l + r >> 1, ans = query(lc[o], l, mid, qr);
if(qr > mid) ans += query(rc[o], mid + 1, r, qr);
return ans;
}
int Query(int ql, int qr) { return query(rt[qr], 0, n, ql - 1) - query(rt[ql-1], 0, n, ql - 1); }
int fa[maxn], siz[maxn], L[maxn], R[maxn];
int findset(int x) { return x == fa[x] ? x : fa[x] = findset(fa[x]); }
int id[maxn];
bool cmp(int a, int b) { return height[a] < height[b]; }
int addv[maxn];
void puttag(int l, int r, int v) {
// printf("puttag: %d %d %d\n", l, r, v);
addv[l] += v; addv[r+1] -= v;
return ;
}
LL Ans[maxn];
int main() {
int N = read(), K = read();
if(K <= 1) {
rep(i, 1, N) {
scanf("%s", S);
int l = strlen(S);
printf("%lld%c", (LL)l * (l + 1) >> 1, i < N ? ' ' : ' ');
}
return 0;
}
int alpha = 26;
rep(i, 1, N) {
scanf("%s", S + 1); int l = strlen(S + 1);
rep(j, 1, l) str[n+j] = S[j] - 'a' + 1, at[n+j] = i;
n += l;
str[++n] = ++alpha;
}
ssort();
calch();
rep(i, 1, n) {
lst[i] = lstv[at[sa[i]]];
if(at[sa[i]]) update(rt[i], rt[i-1], 0, n, lst[i]);
else rt[i] = rt[i-1];
lstv[at[sa[i]]] = i;
fa[i] = L[i] = R[i] = i;
siz[i] = at[sa[i]] > 0;
id[i] = i;
}
sort(id + 1, id + n + 1, cmp);
dwn(I, n, 1) {
int i = id[I];
if(i == 1) continue;
if(!height[i]) break;
int u = findset(i - 1), v = findset(i);
if(u != v) {
fa[v] = u;
int nowL = min(L[u], L[v]), nowR = max(R[u], R[v]), nows = Query(nowL, nowR);
if(nows >= K) {
if(siz[u] < K) puttag(L[u], R[u], height[i]);
if(siz[v] < K) puttag(L[v], R[v], height[i]);
}
siz[u] = nows;
L[u] = nowL; R[u] = nowR;
}
}
rep(i, 1, n) {
addv[i] += addv[i-1];
Ans[at[sa[i]]] += addv[i];
}
rep(i, 1, N) printf("%lld%c", Ans[i], i < N ? ' ' : ' ');
return 0;
}