【bzoj4516】[Sdoi2016]生成魔咒 后缀数组+倍增RMQ+STL-set
题目描述
魔咒串由许多魔咒字符组成,魔咒字符可以用数字表示。例如可以将魔咒字符 1、2 拼凑起来形成一个魔咒串 [1,2]。一个魔咒串 S 的非空字串被称为魔咒串 S 的生成魔咒。
例如 S=[1,2,1] 时,它的生成魔咒有 [1]、[2]、[1,2]、[2,1]、[1,2,1] 五种。S=[1,1,1] 时,它的生成魔咒有 [1]、[1,1]、[1,1,1] 三种。最初 S 为空串。共进行 n 次操作,每次操作是在 S 的结尾加入一个魔咒字符。每次操作后都需要求出,当前的魔咒串 S 共有多少种生成魔咒。
输入
第一行一个整数 n。
第二行 n 个数,第 i 个数表示第 i 次操作加入的魔咒字符。
1≤n≤100000。,用来表示魔咒字符的数字 x 满足 1≤x≤10^9
输出
输出 n 行,每行一个数。第 i 行的数表示第 i 次操作后 S 的生成魔咒数量
样例输入
7
1 2 3 3 3 1 2
样例输出
1
3
6
9
12
17
22
题解
后缀自动机+map 后缀数组+set
SAM太高端了,于是我选择了后缀数组。
好在这题可以离线。
题目中求的是前缀的字串的个数,我们用的是后缀数组,于是需要先读入所有数字,离散化后倒序求sa、rank和height。
由于不重复子串的个数为为总子串个数-重复子串个数,所以我们每次加一个后缀(题目中的前缀),只需求出该串和其它串的重复部分即可。
而该后缀不包括第一个字符的所有串都一定在之前被计算过,不用考虑,只需考虑包括第一个字符的串是否重复出现过即可。
根据sa和rank的定义,挨着的串LCP(最长公共前缀)最大,那么本次计算出的重复的串的个数为之前的后缀的rank值与该后缀的rank值最接近的两个,求出它们与该后缀的LCP中较大的那个算入tmp中。
可能比较难理解。
举个例子,3 3 3 4 3 1中计算后5个以后算后6个,只需要算max(lcp(1,2),lcp(1,5))即可,因为3和33为重复的,取最大值之后是2个重复。
然后把len-tmp加入到答案中并输出即可。
求前驱后继可以使用set,求LCP可以用倍增算法RMQ,此时注意左端点需要+1。
#include <cstdio> #include <algorithm> #include <set> #define N 100010 using namespace std; struct data { int num , p; }a[N]; set<int> s; set<int>::iterator it; int n , v[N] , m , sa[N] , r[N] , ws[N] , wa[N] , wb[N] , wv[N] , rank[N] , height[N] , log[N] , f[N][20]; bool cmp(data a , data b) { return a.num < b.num; } void da() { int i , j , p , *x = wa , *y = wb; for(i = 0 ; i < m ; i ++ ) ws[i] = 0; for(i = 0 ; i < n ; i ++ ) ws[x[i] = r[i]] ++ ; for(i = 1 ; i < m ; i ++ ) ws[i] += ws[i - 1]; for(i = n - 1 ; i >= 0 ; i -- ) sa[--ws[x[i]]] = i; for(p = j = 1 ; p < n ; j <<= 1 , m = p) { for(p = 0 , i = n - j ; i < n ; i ++ ) y[p ++ ] = i; for(i = 0 ; i < n ; i ++ ) if(sa[i] - j >= 0) y[p ++ ] = sa[i] - j; for(i = 0 ; i < n ; i ++ ) wv[i] = x[y[i]]; for(i = 0 ; i < m ; i ++ ) ws[i] = 0; for(i = 0 ; i < n ; i ++ ) ws[wv[i]] ++ ; for(i = 1 ; i < m ; i ++ ) ws[i] += ws[i - 1]; for(i = n - 1 ; i >= 0 ; i -- ) sa[--ws[wv[i]]] = y[i]; for(swap(x , y) , x[sa[0]] = 0 , p = i = 1 ; i < n ; i ++ ) { if(y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + j] == y[sa[i] + j]) x[sa[i]] = p - 1; else x[sa[i]] = p ++ ; } } for(i = 1 ; i < n ; i ++ ) rank[sa[i]] = i; for(p = i = 0 ; i < n - 1 ; height[rank[i ++ ]] = p) for(p ? p -- : 0 , j = sa[rank[i] - 1] ; r[i + p] == r[j + p] ; p ++ ); } int query(int x , int y) { int k = log[y - x + 1]; return min(f[x][k] , f[y - (1 << k) + 1][k]); } int main() { int i , j , tmp; long long ans = 0; scanf("%d" , &n); for(i = 0 ; i < n ; i ++ ) scanf("%d" , &a[i].num) , a[i].p = i; sort(a , a + n , cmp); for(i = 0 ; i < n ; i ++ ) { if(a[i].num != v[m]) v[++m] = a[i].num; r[n - a[i].p - 1] = m; } n ++ , m ++ , da() , n -- ; for(i = 2 ; i <= n ; i ++ ) log[i] = log[i >> 1] + 1; for(i = 1 ; i <= n ; i ++ ) f[i][0] = height[i]; for(i = 1 ; i <= log[n] ; i ++ ) for(j = 1 ; j + (1 << i) - 1 <= n ; j ++ ) f[j][i] = min(f[j][i - 1] , f[j + (1 << (i - 1))][i - 1]); for(i = n - 1 ; i >= 0 ; i -- ) { tmp = 0; it = s.upper_bound(rank[i]); if(it != s.end()) tmp = max(tmp , query(rank[i] + 1 , *it)); if(it != s.begin()) tmp = max(tmp , query(*(--it) + 1 , rank[i])); ans += (long long)n - i - tmp; printf("%lld\n" , ans); s.insert(rank[i]); } return 0; }