Sum of Squares of the Occurrence Counts解题报告(后缀自动机+LinkCutTree+线段树思想)
题目描述
给定字符串\(S(|S|\le10^5)\),对其每个前缀求出如下的统计量:
对该字符串中的所有子串,统计其出现的次数,求其平方和。
Sample Input:
aaa
Sample Output:
1 5 14
详细题解
1、只求整个串的答案的解决方案
首先可一眼想到后缀自动机。
对后缀自动机上每个状态,定义endpos为所有能走到该状态的子串中子串右端点的取值集合。如果求出其endpos位置个数\(x\),那么就能求得该状态对答案的贡献,为\(x^2*(r-l+1)\),其中\(l,r\)分别为该状态的最小和最大子串长度。\(l,r\)直接由状态的len得到。考虑每个结点endpos如何求。
定理:字符串\(S\)建出的自动机上,\(S\)的每个前缀走到的结点一定是非拷贝结点。
证明:当建出前缀\(S_i\)的自动机时,显然串\(S_i\)走到的结点是非拷贝结点。该结点的len值为\(|S_i|\)。若该结点被拷贝,由后缀自动机的性质,拷贝结点的len值严格小于原先结点的len值,因此串\(|S_i|\)走到的结点在之后不可能变化到拷贝结点,否则将与拷贝结点的len值将大于等于\(|S_i|\)。矛盾,证毕。
定理:字符串\(S\)建出的自动机上,将\(S\)的每个前缀走到的结点标记为1后,每个状态的endpos大小就是该状态对应自动机fail树子树中标记为1的结点个数。
证明:考虑一个子串\(s\)到达某状态\(ST\),那么必然可以将\(s\)左侧添加字符串\(t\)使得\(ts\)是\(S\)的一个前缀。那么\(ts\)所到达的自动机状态必然在\(ST\)状态的fail树子树中(由后缀自动机定义)。因此所有endpos位置都对应了fail树子树中一个标记为1的结点。另一方面,对两个不同的标记为1的结点,对应的endpos位置一定不同。因此fail树子树中标记为1的结点个数不重不漏的对应了endpos集合大小。证毕。
根据以上两个定理,建出自动机后,dfs fail树即可在线性时间内求出答案。
2、转化为数据结构问题
下面考虑当自动机添加一个字符时如何更新答案,这只需考虑哪些状态的endpos个数变多以及哪些状态\(l,r\)改变。
(1)endpos变化:
情况1:没有拷贝结点,则fail树中新的last结点直接指向某结点\(j\),last沿fail指针一路走上去的这些结点endpos大小加1。
情况2:有拷贝结点,则结点\(j\)子树被断开,\(j\)和新的last指向copy,copy指向\(j\)原来的link。仍然是last沿fail指针一路走上去的结点endpos+1。此外,copy结点具有\(j\)结点的所有endpos。
(2)\(l,r\)变化:
首先要考虑新的last以及copy这两个新结点。其次,当结点\(j\)子树被断开,然后\(j\)的link指向copy结点后,\(j\)的\(l,r\)区间会变化。除此之外没有别的结点。
于是问题转化为一条树链endpos大小加1,单点修改\(r,l\)的值,询问所有状态的\(endpos^2 \times (r-l+1)\)的和(称其为endpos的加权平方和,\(r-l+1\)为权)。又由于需要支持树的断开和合并,因此必须使用LinkCutTree。
3、问题的解决
首先我们需要考虑区间加1,询问全局平方和如何做。这是经典线段树题目,需要维护区间和,并用区间和更新区间平方和即可。
现在转到树中,只需要把线段树换成splay。splay中维护endpos值、\(r-l+1\)值(即权)、权的和、endpos的加权和以及加权平方和。
当链上加1后,endpos的加权和增量是权的和,而加权平方和增量可用endpos加权和以及权的和表示。
当修改\(j\)结点的\(r-l+1\)时,只需将\(j\)对应splay旋转到根,然后可直接对该结点访问或修改(不会影响其他结点)。
最后考虑如何求全局和。对于树链修改,只需记录修改前后树链增量即可。对于修改\(j\)结点的\(r-l+1\)以及添加copy结点,这合起来不会改变答案。
因此总时间复杂度\(O(n(|\sum|+\log n)\)。
核心代码
由于此题代码过长,在此略去LCT部分模板,值保留关键部分的代码供理解算法思想。代码如下:
1 #define LL long long 2 struct Tree{ 3 Tree *left, *right, *fa; 4 int endpos, len, delta; 5 LL sumLen, sum, sum2; 6 }lct[MAXN]; 7 int num; 8 inline Tree* newNode(int endpos, int len){ 9 lct[++num] = { lct, lct, lct, endpos, len, 0, len, (LL)endpos * len, (LL)endpos * endpos * len }; 10 return &lct[num]; 11 } 12 inline void pushUp(Tree * rt){ 13 Tree *t1 = rt->left, *t2 = rt->right; 14 LL endpos = rt->endpos; 15 rt->sumLen = rt->len + t1->sumLen + t2->sumLen; 16 rt->sum = endpos * rt->len + t1->sum + t2->sum; 17 rt->sum2 = endpos * endpos * rt->len + t1->sum2 + t2->sum2; 18 } 19 inline void update(Tree *t, int delta){ 20 t->delta += delta; 21 t->endpos += delta; 22 t->sum2 += 2 * delta*t->sum + t->sumLen*delta*delta; 23 t->sum += t->sumLen*delta; 24 } 25 inline void pushDown(Tree *rt){ 26 if (rt->delta){ 27 if (rt->left != lct)update(rt->left, rt->delta); 28 if (rt->right != lct)update(rt->right, rt->delta); 29 rt->delta = 0; 30 } 31 } 32 LL add(Tree *rt) 33 { 34 Tree *t = access(rt); 35 LL val = t->sum2; 36 update(t, 1); 37 return t->sum2 - val; 38 } 39 void changeLen(Tree *rt, int len) 40 { 41 LL endpos = rt->endpos, t = len - rt->len; 42 rt->sum2 += t * endpos * endpos; 43 rt->sum += t * endpos; 44 rt->sumLen += t; 45 rt->len = len; 46 } 47 LL add(char ch) 48 { 49 int c = convert(ch); 50 int cur = ++cnt, i; 51 st[cur].len = st[last].len + 1; 52 memset(st[cur].next, 0, sizeof(st[cur].next)); 53 for (i = last; i != -1 && !st[i].next[c]; i = st[i].link) 54 st[i].next[c] = cur; 55 if (i == -1){ 56 st[cur].link = 0; 57 newNode(0, st[cur].len); 58 } 59 else{ 60 int j = st[i].next[c]; 61 if (st[i].len + 1 == st[j].len){ 62 st[cur].link = j; 63 newNode(0, st[cur].len - st[j].len); 64 link(&lct[cur], &lct[j]); 65 } 66 else{ 67 int copy = ++cnt; 68 st[copy].len = st[i].len + 1; 69 memcpy(st[copy].next, st[j].next, sizeof(st[j].next)); 70 st[copy].link = st[j].link; 71 for (; i != -1 && st[i].next[c] == j; i = st[i].link) 72 st[i].next[c] = copy; 73 st[j].link = st[cur].link = copy; 74 cut(&lct[j]); 75 //cut后lct[j]是splay的根,对元素的修改和访问可直接进行 76 changeLen(&lct[j], st[j].len - st[copy].len); 77 newNode(0, st[cur].len - st[copy].len); 78 newNode(lct[j].endpos, st[copy].len - st[st[copy].link].len); 79 link(&lct[j], &lct[copy]); 80 link(&lct[cur], &lct[copy]); 81 link(&lct[copy], &lct[st[copy].link]); 82 } 83 } 84 last = cur; 85 return add(&lct[cur]); 86 } 87 char s[100001]; 88 int main() 89 { 90 scanf("%s", s); 91 LL ans = 0; 92 init(); 93 for (int i = 0; s[i]; i++){ 94 ans += add(s[i]); 95 printf("%lld\n", ans); 96 } 97 }