Palisection CodeForces - 17E (回文树)
题意:给一个字符串,然后求相交的回文串的对数。
思路:求出总对数,然后利用num[i]求出不相交的对数减去即可。
顺便学了mod非质数下的逆元求法,要用扩展欧几里得求,但是也有限制要分母与mod互质才行
#include<bits/stdc++.h>
using namespace std;
#define ls rt<<1
#define rs (rt<<1)+1
#define ll long long
#define fuck(x) cout<<#x<<" "<<x<<endl;
const int maxn=2e6+10;
const ll mod=51123987;
typedef pair<char,int> pci;
int d[4][2]={1,0,-1,0,0,1,0,-1};
char s[maxn];
int pre[maxn];
const int MAXN = 2e6+10;//长度
const int N = 26;//字符集大小
struct Palindromic_Tree {
vector<pci>next[MAXN];//next指针,next指针和字典树类似,指向的回文子串为i节点对应的回文子串两端加上同一个字符ch构成
int fail[MAXN];//fail指针,失配后跳转到fail指针指向的节点,fail指针指向的是i节点对应的回文子串的最长后缀回文子串(是真后缀),这个匹配过程与kmp有点类似,fail[i]表示节点i失配以后跳转到长度小于该串且以该节点表示回文串的最后一个字符结尾的最长回文串表示的节点
int cnt[MAXN];//在调用count函数之后,cnt[i]表示i节点对应的回文子串的出现次数的准确值
int num[MAXN];//在调用add函数之后返回num[last]可以得到以i位置的字符为尾的回文串个数
int len[MAXN];//len[i]表示节点i表示的回文串的长度
int S[MAXN];//存放添加的字符,
int last;//指向上一个字符所在的节点,方便下一次add
int n;//字符数组指针,从1开始,到n结束
int p;//节点指针,0指向偶根,1指向奇根,有效编号到p-1
int newnode(int l) {//新建节点
next[p].clear();
cnt[p] = 0;
num[p] = 0;
len[p] = l;
fail[p]=0;
return p++;
}
void init() {//初始化
p = 0;
newnode(0);
newnode(-1);
last = 0;
n = 0;
S[n] = -1;//开头放一个字符集中没有的字符,减少特判
fail[0] = 1;
}
int get_fail(int x) {//和KMP一样,失配后找一个尽量最长的
while (S[n - len[x] - 1] != S[n]) x = fail[x];
return x;
}
int add(int c) {
c -= 'a';
S[++n] = c;
int cur = get_fail(last),tmp=-1;//通过上一个回文串找这个回文串的匹配位置
for(int i=0;i<next[cur].size();i++) if(next[cur][i].first==c+'a') {tmp=next[cur][i].second;break;}
if (tmp==-1) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串
int now = newnode(len[cur] + 2);//新建节点
tmp=get_fail(fail[cur]);
for(int i=0;i<next[tmp].size();i++) if(next[tmp][i].first==c+'a') {fail[now]=next[tmp][i].second;break;}
next[cur].push_back(make_pair(c+'a',now));
num[now] = num[fail[now]] + 1;
tmp=now;
}
last = tmp;
cnt[last]++;
return num[last];
}
void count() {
for (int i = p - 1; i >= 0; --i) cnt[fail[i]] += cnt[i];
//父亲累加儿子的cnt,因为如果fail[v]=u,则u一定是v的子回文串!
}
}pt;
void exgcd(ll a,ll b,ll &x,ll &y,ll &d)
{
if(!b) {
x=1,y=0,d=a;
return;
}
ll xx,yy;
exgcd(b,a%b,xx,yy,d);
x=yy,y=xx-(a/b)*yy;
}
//a在模m意义下的逆元 要求a,m互质 O(logn)
ll inverse(ll a,ll m)
{
ll x,y,d;
exgcd(a,m,x,y,d);//ax+py=1
if(d!=1) return -1;
return (x%m+m)%m;
}
int main(){
int len;
while(scanf("%d",&len)!=EOF) {
scanf("%s",s+1);
ll ans = 0;
pt.init();
for (int i = 1; i <= len; i++)
pre[i] = (pre[i - 1] + pt.add(s[i]))%mod;
pt.count();
for(int i=2;i<=pt.p-1;i++) ans+=pt.cnt[i];
ans=((ans%mod)*((ans-1)%mod)%mod)*inverse(2,mod)%mod;
pt.init();
int tmp=0;
for (int i = len; i >= 1; i--)
ans = (ans-1LL * pre[i - 1] * pt.add(s[i])%mod+mod)%mod;
printf("%I64d\n", ans);
}
return 0;
}
要用扩展欧几里得求