【bzoj3879】SvT 后缀数组+倍增RMQ+单调栈
题目描述
(我并不想告诉你题目名字是什么鬼)
有一个长度为n的仅包含小写字母的字符串S,下标范围为[1,n].
现在有若干组询问,对于每一个询问,我们给出若干个后缀(以其在S中出现的起始位置来表示),求这些后缀两两之间的LCP(LongestCommonPrefix)的长度之和.一对后缀之间的LCP长度仅统计一遍.
输入
第一行两个正整数n,m,分别表示S的长度以及询问的次数.
接下来一行有一个字符串S.
接下来有m组询问,对于每一组询问,均按照以下格式在一行内给出:
首先是一个整数t,表示共有多少个后缀.接下来t个整数分别表示t个后缀在字符串S中的出现位置.
输出
样例输入
7 3
popoqqq
1 4
2 3 5
4 1 2 5 6
样例输出
0
0
2
题解
后缀数组+倍增RMQ+单调栈
首先预处理出sa和height数组。
然后对于每组询问,将要求的后缀去重后按照rank从小到大排序。
由于我们有:LCP(a,c)=min(LCP(a,b),LCP(b,c)),其中rank[a]<rank[b]<rank[c]
所以我们只需要知道相邻两个要求的后缀之间的LCP,即可推出任意两个后缀的LCP。
这里求LCP的方式是倍增RMQ,所以我偷改了height的定义:height[i][j]表示排名为i-2^j的后缀与排名为i的后缀的LCP。
这样转化成了一个新的问题:给你n个数,求其每个子区间中最小值的和。
考虑对答案的贡献:ai对答案的贡献是满足l∈[lpos,i],r∈[i,rpos]的所有区间[l,r],也即ai*(i-lpos+1)*(rpos-i+1),其中lpos是i左侧最后一个大于i的,rpos是i右侧最后一个大于等于i的。
(左右包含等号的情况不同是为了处理相同的数,防止重复或漏算)
可以用一个单调栈来在线性时间内求出i-lpos+1和rpos-i+1,具体方法见代码。
最后的最后,需要把用于去重的数组vis清零,注意不能用memset。
#include <cstdio> #include <cstring> #include <algorithm> #define N 500010 #define mod 23333333333333333ll using namespace std; int n , m , sa[N] , r[N] , ws[N] , wa[N] , wb[N] , wv[N] , rank[N] , height[N][21] , log[N] , num[N * 6] , vis[N] , pos[N] , val[N] , sta[N] , top , lp[N] , rp[N]; char str[N]; void da() { int i , j , p , *x = wa , *y = wb; for(i = 0 ; i < m ; i ++ ) ws[i] = 0; for(i = 0 ; i < n ; i ++ ) ws[x[i] = r[i]] ++ ; for(i = 1 ; i < m ; i ++ ) ws[i] += ws[i - 1]; for(i = n - 1 ; i >= 0 ; i -- ) sa[--ws[x[i]]] = i; for(p = j = 1 ; p < n ; j <<= 1 , m = p) { for(p = 0 , i = n - j ; i < n ; i ++ ) y[p ++ ] = i; for(i = 0 ; i < n ; i ++ ) if(sa[i] - j >= 0) y[p ++ ] = sa[i] - j; for(i = 0 ; i < n ; i ++ ) wv[i] = x[y[i]]; for(i = 0 ; i < m ; i ++ ) ws[i] = 0; for(i = 0 ; i < n ; i ++ ) ws[wv[i]] ++ ; for(i = 1 ; i < m ; i ++ ) ws[i] += ws[i - 1]; for(i = n - 1 ; i >= 0 ; i -- ) sa[--ws[wv[i]]] = y[i]; for(swap(x , y) , x[sa[0]] = 0 , p = i = 1 ; i < n ; i ++ ) { if(y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + j] == y[sa[i] + j]) x[sa[i]] = p - 1; else x[sa[i]] = p ++ ; } } for(i = 1 ; i < n ; i ++ ) rank[sa[i]] = i; for(p = i = 0 ; i < n - 1 ; height[rank[i ++ ]][0] = p) for(p ? p -- : 0 , j = sa[rank[i] - 1] ; r[i + p] == r[j + p] ; p ++ ); } int query(int x , int y) { x ++ ; int k = log[y - x + 1]; return min(height[x + (1 << k) - 1][k] , height[y][k]); } bool cmp(int a , int b) { return rank[a] < rank[b]; } int main() { int i , j , k , cnt , tot; long long ans; scanf("%d%d%s" , &n , &k , str); for(i = 0 ; i < n ; i ++ ) r[i] = str[i] - 'a' + 1; n ++ , m = 28 , da() , n -- ; for(i = 2 ; i <= n ; i ++ ) log[i] = log[i >> 1] + 1; for(i = 1 ; i <= log[n] ; i ++ ) for(j = (1 << i) ; j <= n ; j ++ ) height[j][i] = min(height[j][i - 1] , height[j - (1 << (i - 1))][i - 1]); while(k -- ) { scanf("%d" , &cnt); tot = 0 , ans = 0; for(i = 1 ; i <= cnt ; i ++ ) { scanf("%d" , &num[i]) , num[i] -- ; if(!vis[num[i]]) vis[num[i]] = 1 , pos[++tot] = num[i]; } sort(pos + 1 , pos + tot + 1 , cmp); for(i = 1 ; i < tot ; i ++ ) val[i] = query(rank[pos[i]] , rank[pos[i + 1]]); sta[0] = top = 0; for(i = 1 ; i < tot ; i ++ ) { while(top && val[sta[top]] > val[i]) top -- ; lp[i] = i - sta[top] , sta[++top] = i; } sta[0] = tot , top = 0; for(i = tot - 1 ; i ; i -- ) { while(top && val[sta[top]] >= val[i]) top -- ; rp[i] = sta[top] - i , sta[++top] = i; } for(i = 1 ; i < tot ; i ++ ) ans = (ans + (long long)lp[i] * rp[i] * val[i]) % mod; printf("%lld\n" , ans); for(i = 1 ; i <= cnt ; i ++ ) vis[num[i]] = 0; } return 0; }