[HAOI2016]找相同字符

传送门

两个串嘛……可以建广义后缀自动机。

我们每次要记录一下对于每个节点,其对应的在第一个串上的size和第二个串上的size,那么每个节点对于答案的贡献就是\(size[0] * size[1] * (l[i] - l[fa[i]])\)

解释一下,size其实表示的就是endpos集合之内的元素个数 ,也就是不同的位置,后面的差值是节点对应的后缀长度区间,所以他们的乘积就是答案了。

然后这题要注意!你在建立SAM的时候,可以每次把新加入节点的size设为1,但是继续建立广义后缀自动机的时候,因为公用节点的关系,所以我们每次在记录size的时候,需要从开始的地方往后跳转移,把转移到的位置的size设为1,这样建立就可以了。

#include<bits/stdc++.h>
#define rep(i,a,n) for(register int i = a;i <= n;i++)
#define per(i,n,a) for(register int i = n;i >= a;i--)
#define enter putchar('\n')
#define pr pair<int,int>
#define mp make_pair
#define fi first
#define sc second
using namespace std;
typedef long long ll;
const int M = 200005;
const int N = 10000005;
const int INF = 2147483647;

int read()
{
   int ans = 0,op = 1;
   char ch = getchar();
   while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
   while(ch >='0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
   return ans * op;
}

char s[M];
int n1,n2,a[M<<2],c[M<<2];
ll ans;

struct Suffix
{
   int last,cnt,ch[M<<2][26],fa[M<<2],size[M<<2][2],l[M<<2];
   void extend(int c,int f)
   {
      int p = last,np = ++cnt;
      last = cnt,l[np] = l[p] + 1;
      if(!f) size[np][f] = 1;
      while(p && !ch[p][c]) ch[p][c] = np,p = fa[p];
      if(!p) {fa[np] = 1;return;}
      int q = ch[p][c];
      if(l[q] == l[p] + 1) fa[np] = q;
      else
      {
     int nq = ++cnt;
     l[nq] = l[p] + 1,memcpy(ch[nq],ch[q],sizeof(ch[q]));
     fa[nq] = fa[q],fa[q] = fa[np] = nq;
     while(ch[p][c] == q) ch[p][c] = nq,p = fa[p];
      }
   }
   void cal()
   {
      rep(i,1,cnt) c[l[i]]++;
      rep(i,1,cnt) c[i] += c[i-1];
      rep(i,1,cnt) a[c[l[i]]--] = i;
      per(i,cnt,1)
      {
     int p = a[i];
     size[fa[p]][0] += size[p][0],size[fa[p]][1] += size[p][1];
     //ans += 1ll * size[p][0] * size[p][1] * (l[p] - l[fa[p]] + 1);
      }
      rep(i,1,cnt) ans += 1ll * size[i][0] * size[i][1] * (l[i] - l[fa[i]]);
      printf("%lld\n",ans);
   }
}SAM;

int main()
{
   SAM.cnt = SAM.last = 1;
   scanf("%s",s+1),n1 = strlen(s+1);
   rep(i,1,n1) SAM.extend(s[i] - 'a',0);
   scanf("%s",s+1),n2 = strlen(s+1),SAM.last = 1;
   int p = 1;
   rep(i,1,n2)
   {
      SAM.extend(s[i] - 'a',1);
      p = SAM.ch[p][s[i]-'a'],SAM.size[p][1] = 1;
   }
   SAM.cal();
   return 0;
}

posted @ 2019-01-12 22:06  CaptainLi  阅读(131)  评论(0编辑  收藏  举报