HDU5785 Interesting(Manacher + 延迟标记)
题目
Source
http://acm.hdu.edu.cn/showproblem.php?pid=5785
Description
Alice get a string S. She thinks palindrome string is interesting. Now she wanna know how many three tuple (i,j,k) satisfy 1≤i≤j<k≤length(S), S[i..j] and S[j+1..k] are all palindrome strings. It's easy for her. She wants to know the sum of i*k of all required three tuples. She can't solve it. So can you help her? The answer may be very large, please output the answer mod 1000000007.
A palindrome string is a string that is same when the string is read from left to right as when the string is read from right to left.
Input
The input contains multiple test cases.
Each test case contains one string. The length of string is between 1 and 1000000. String only contains lowercase letter.
Output
For each test case output the answer mod 1000000007.
Sample Input
aaa
abc
Sample Output
14
8
分析
题目大概说给一个字符串,找到其所有子串[i...k]满足它是由两个回文串拼成的,求Σi*k。
官方题解这么说的:
用manacher算法O(n)求出所有的回文半径。有了回文半径后,就可以求出cntL[i]表示以i结尾的回文串的起始位置的和cntR[i]表示以i起始的回文串的结尾位置的和,然后就可以求出答案了,这里要注意奇偶长度回文串的不同处理。复杂度O(n)。
本渣渣看了好久想了好久。。才反应过来x1*y1+x1*y2+x2*y1+x2*y2=(x1+x2)*(y1+y2) 。。这个真反应不过来= =太渣了。。那么只要求出cntL[i]和cntR[i],Σ(cntL[i]*cntR[i+1])就是要的答案了。
本渣渣又看了好久想了好久。。才想到怎么在O(n)求出cntL[i]和cntR[i]。。
- 首先,跑一下Manacher就能知道每个位置向左和向右最多能延伸的长度
比如中间位置是i(回文串是奇数的情况,偶数同理。另外先不管跑Manacher前插入的特殊字符,因为也同理。。),p[i]表示Manacher求得的延伸半径。
- 那么[i-p[i]+1, i+p[i]-1]就是回文串,故cntL[i+p[i]-1]就该加上i-p[i]+1;
- 而[i-p[i]+2, i+p[i]-2]也是回文串,cntL[i+p[i]-2]就该加上i-p[i]+2…………
总之可以知道这个就相当于,对于cntL数组要在[i, i+p[i]-1]区间上依次加上一个末项为i-p[i]+1且公差为-1的等差数列,cntR也是同理的这儿同样就不说了。
- 那么现在问题就是怎么对区间更新,让区间[L,R]加上一个首项k公差-1的等差数列,这些更新操作完成后要进行单点的查询,而且整个的时间复杂度要为O(n)。
区间更新自然就联想到延迟标记。即,更新操作就用O(1)打打标记,所有更新操作完成后,从左往右O(n)遍历过去,把标记传递过去,并更新真实的值。
- 那么打什么标记?
首先能想到有一个加标记sumtag,表示这个数需要加多少。对于让区间[L,R]加上一个首项k公差-1的等差数列,就让sumtag[L]+=k,然后传递标记中访问到L时把sumtag[L]加到真实的值里,并把标记下传,即sumtag[L+1]+=sumtag[L]-1。不过考虑到多个区间修改有叠加,那么应该还要开几个数组记录各个点有多少个更新操作,sumcnt,那么下传也就要改成sumtag[L+1]+=sumtag[L]-sumcnt[L],另外sumcnt也要传递下去,即sumcnt[L+1]+=sumcnt[L]。
不过这样还不够,因为这么传递会一直传递到数组尾巴,也就是说相当于更新了[L,MAXN]区间,而我们要的是[L,R]区间的更新,多更新了[R+1,MAXN]这个区间,而这个区间显然也是加上了一个等差数列!所以,考虑再用一个标记,减标记subtag,表示区间需要减的数,对于多更新的只要在subtag[R+1]这儿打上个标记就OK了!这个与sumtag是一样的,最后遍历标记下传的同时减去subtag的值即可,当然也需要subcnt。
于是这题就能这么解决了。。
代码
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; #define MAXN 1111111 char s[MAXN<<1]; int p[MAXN<<1]; void manacher(){ int mx=0,id; for(int i=1;s[i];++i){ if(mx>i) p[i]=min(p[2*id-i],mx-i); else p[i]=1; for(;s[i+p[i]]==s[i-p[i]];++p[i]); if(p[i]+i>mx){ mx=p[i]+i; id=i; } } } char str[MAXN]; long long val[2][MAXN]; long long addtag[2][MAXN],addcnt[2][MAXN],subtag[2][MAXN],subcnt[2][MAXN]; void update(int flag,int l,int r,int k){ addtag[flag][l]+=k; addtag[flag][l]%=1000000007; ++addcnt[flag][l]; subtag[flag][r+1]+=k-r+l; subtag[flag][r+1]%=1000000007; ++subcnt[flag][r]; } void pushdown(){ for(int i=1; str[i]; ++i){ if(addcnt[0][i]){ val[0][i]+=addtag[0][i]; val[0][i]%=1000000007; addtag[0][i+1]+=addtag[0][i]-addcnt[0][i]; addtag[0][i+1]%=1000000007; addcnt[0][i+1]+=addcnt[0][i]; } if(subcnt[0][i]){ val[0][i]-=subtag[0][i]; val[0][i]%=1000000007; subtag[0][i+1]+=subtag[0][i]-subcnt[0][i]; subtag[0][i+1]%=1000000007; subcnt[0][i+1]+=subcnt[0][i]; } if(addcnt[1][i]){ val[1][i]+=addtag[1][i]; val[1][i]%=1000000007; addtag[1][i+1]+=addtag[1][i]-addcnt[1][i]; addtag[1][i+1]%=1000000007; addcnt[1][i+1]+=addcnt[1][i]; } if(subcnt[1][i]){ val[1][i]-=subtag[1][i]; val[1][i]%=1000000007; subtag[1][i+1]+=subtag[1][i]-subcnt[1][i]; subtag[1][i+1]%=1000000007; subcnt[1][i+1]+=subcnt[1][i]; } } } int main(){ while(~scanf("%s",str+1)){ s[0]='$'; int i=1; for(;str[i];++i){ s[(i<<1)-1]='#'; s[i<<1]=str[i]; } s[(i<<1)-1]='#'; s[i<<1]=0; manacher(); memset(val,0,sizeof(val)); memset(addtag,0,sizeof(addtag)); memset(addcnt,0,sizeof(addcnt)); memset(subtag,0,sizeof(subtag)); memset(subcnt,0,sizeof(subcnt)); for(int i=1; s[i]; ++i){ if(i&1){ if(p[i]/2==0) continue; update(0,i/2-p[i]/2+1,i/2,i/2+p[i]/2); update(1,i/2+1,i/2+p[i]/2,i/2); }else{ update(0,i/2-p[i]/2+1,i/2,i/2+p[i]/2-1); update(1,i/2,i/2+p[i]/2-1,i/2); } } pushdown(); long long ans=0; for(int i=1; str[i]&&str[i+1]; ++i){ ans+=val[1][i]*val[0][i+1]; ans%=1000000007; } if(ans<0) ans+=1000000007; printf("%I64d\n",ans); } return 0; }