[NOI2018]你的名字

嘟嘟嘟


这题以前写过弃掉了,后来竟然连自己的68分写法都看不懂了……
这次回首这道题,心想怎么说也得把这题切了,哪怕抄题解也行。
但没想到别人的题解自己怎么也看不懂,最终还是自己搞出来了(我真nb)。
总用时前一天下午到第二天凌晨0:30+第二天半个上午。


我们先来回顾\(L = 1, R = n\)的情况。
大体思路就是求出本质不同的公共子串数目,然后用\(T\)串的本质不同的子串减去即可。
\(T\)的本质不同的子串个数大家都会求,就不在这唠叨了。我们重点看前面的怎么求。
我们在\(T\)的后缀自动机中插入一个字符\(t[i]\)之后,马上记录\(t[i]\)所在结点的父亲的长度,即\(len[link[pos[i]]]\)。这个是什么东西呢?记前面那个值为\(Len\),他表示的是左端点为\(1\)~\(Len - 1\),右端点为\(i\)的这些子串第一次的出现位置都是\(i\)。因为新建的这个结点的endpos必定只有这个\(i\),那么属于这个结点的所有子串的出现位置也就只有\(i\)了。
注意,一定是插入完后马上统计。如果把后缀自动机建完后再一个个统计就不对了。因为有的结点的父亲已经改变。
我们记上面求出的这个数组为\(ha[i]\)


求出了上面这个东西后就好办了。我们把\(T\)放在\(S\)的后缀自动机上跑,记\(T\)到位置\(i\)时的匹配长度为\(l\),那么如果\(l >ha[i]\)的话,就表示在在\(T\)的第\(i\)个位置,本质不同的公共子串数目增加了\(l - ha[i]\),把他加到答案里即可。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 5e5 + 5;
inline ll read()
{
  ll ans = 0;
  char ch = getchar(), last = ' ';
  while(!isdigit(ch)) last = ch, ch = getchar();
  while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
inline void write(ll x)
{
  if(x < 0) x = -x, putchar('-');
  if(x >= 10) write(x / 10);
  putchar(x % 10 + '0');
}

char s[maxn], t[maxn];
int ha[maxn];
struct Sam
{
  int cnt, las;
  int tra[maxn << 1][27], len[maxn << 1], link[maxn << 1];
  In void init()
  {
    link[cnt = las = 0] = -1;
    Mem(tra[0], 0);
  }
  In void insert(int c)
  {
    int now = ++cnt, p = las; Mem(tra[cnt], 0);
    len[now] = len[las] + 1;
    while(~p && !tra[p][c]) tra[p][c] = now, p = link[p];
    if(p == -1) link[now] = 0;
    else
      {
        int q = tra[p][c];
        if(len[q] == len[p] + 1) link[now] = q;
    	else
      	{
          int clo = ++cnt; Mem(tra[cnt], 0);
          memcpy(tra[clo], tra[q], sizeof(tra[q]));
          len[clo] = len[p] + 1;
          link[clo] = link[q], link[q] = link[now] = clo;
          while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p];
        }
      }
    las = now;
  }
  int pos[maxn << 1], buc[maxn << 1];
  In ll lcs(char* s)
  {
    ll ret = 0;
    int n = strlen(s);
    for(int i = 0, l = 0, p = 0; i < n; ++i)
      {
	    int c = s[i] - 'a';
	    while(~p && !tra[p][c]) p = link[p], l = len[p];
	    if(p == -1) p = l = 0;
	    else ++l, p = tra[p][c];
	    if(l > ha[i]) ret += l - ha[i];
      }
    return ret;
  }
}S, T;

In void work0()
{
  int len = strlen(t);
  ll tp = 0;
  for(int i = 0; i < len; ++i) tp += i + 1 - ha[i];    //根据ha[i]的定义,我们也可以这么求本质不同的子串个数
  write(tp - S.lcs(t)), enter;
}

int main()
{
  scanf("%s", s);
  int n = strlen(s); S.init();
  for(int i = 0; i < n; ++i) S.insert(s[i] - 'a');
  int m = read();
  for(int i = 1; i <= m; ++i)
    {
      scanf("%s", t); int len = strlen(t);
      T.init();
      for(int j = 0; j < len; ++j) T.insert(t[j] - 'a'), ha[j] = T.len[T.link[T.las]];
      int L = read(), R = read();
      if(L == 1 && R == n) {work0(); continue;}
    }
  return 0;
}

当当当当!下面开始讲正解! 现在有了$L, R$的限制。一个直观的想法就是线段树合并求出每一个结点的endpos集合。然后匹配的时候除了满足这个节点有字符$c$的转移,还要保证转移后到达的结点的endpos有在$[L, R]$里面的。 这样一直跳,如果他最终在位置$i$时的匹配长度为$l$,我们要考虑从endpos向前$l$长度的字符串超没超出$L$的限制,即$endpos - l$是否小于$L$,如果小于,那么匹配的长度实际是$endpos - L + 1$。然后看这个长度是否大于$ha[i]$,如果大于,就减去$ha[i]$加到答案里。 那么为了让匹配的长度尽量不超过下界$L$,我们要找的是在$[L, R]$区间内尽量靠右的endpos。这个线段树维护最大值即可。
表面上看这么做似乎就万事大吉了。但实际上68分以后照样WA。 为啥咧? 因为我们查找endpos的时候,是在能匹配的最长长度的结点上的线段树查询的,长度越长,endpos就越少,也就说明,我们在较长的匹配上查到的endpos减去$l$后会超过$L$的限制,而且可能超出很多,导致剩下的部分比$ha[i]$小。但是如果我们在匹配长度较短的线段树上查找的话,反而能找到一个endpos,在$L$的限制之内匹配的长度比$ha[i]$大,或者所有的匹配长度干脆都不会超出$L$。这时候就需要把这些长度加到答案里。 因此,匹配的时候,我们不仅要在这个结点的线段树上查找endpos计算贡献。还要到他的祖先结点上执行同样的过程,直到匹配长度不受$L$的限制再退出。接着去匹配下一位。
自认为码风还是十分工整的。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 5e5 + 5;
const int maxt = 4e7 + 5;
inline ll read()
{
  ll ans = 0;
  char ch = getchar(), last = ' ';
  while(!isdigit(ch)) last = ch, ch = getchar();
  while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
inline void write(ll x)
{
  if(x < 0) x = -x, putchar('-');
  if(x >= 10) write(x / 10);
  putchar(x % 10 + '0');
}
In void MYFILE()
{
#ifndef mrclr
  freopen("name.in", "r", stdin);
  freopen("name.out", "w", stdout);
#endif
}

int n, L, R;
char s[maxn], t[maxn];
int ha[maxn];

struct Tree
{
  int ls, rs;
  int rpos;
}seg[maxt];
int root[maxn << 1], cur[maxn], tcnt = 0;
#define LS seg[now].ls
#define RS seg[now].rs
In void update(int& now, int l, int r, int id)
{
  now = ++tcnt;
  seg[now].rpos = -1;
  if(l == r) {seg[now].rpos = l; return;}
  int mid = (l + r) >> 1;
  if(id <= mid) update(LS, l, mid, id);
  else update(RS, mid + 1, r, id);
  seg[now].rpos = max(seg[LS].rpos, seg[RS].rpos);
}
In int merge(int x, int y, int l, int r)
{
  if(!x || !y) return x | y;
  if(l == r) {seg[x].rpos |= seg[y].rpos; return x;}
  int mid = (l + r) >> 1, z = ++tcnt;
  seg[z].ls = merge(seg[x].ls, seg[y].ls, l, mid);
  seg[z].rs = merge(seg[x].rs, seg[y].rs, mid + 1, r);
  seg[z].rpos = max(seg[x].rpos, seg[y].rpos);
  return z;
}
In int query(int now, int l, int r, int L, int R)
{
  if(!now) return -1;
  if(l == L && r == R) return seg[now].rpos;
  int mid = (l + r) >> 1;
  if(R <= mid) return query(LS, l, mid, L, R);
  else if(L > mid) return query(RS, mid + 1, r, L, R);
  else return max(query(LS, l, mid, L, mid), query(RS, mid + 1, r, mid + 1, R));
}

struct Sam
{
  int cnt, las;
  int tra[maxn << 1][27], len[maxn << 1], link[maxn << 1];
  In void init()
  {
    link[cnt = las = 0] = -1;
    Mem(tra[0], 0);
  }
  In void insert(int c, int id)
  {
    int now = ++cnt, p = las; Mem(tra[cnt], 0);
    len[now] = len[las] + 1;
    while(~p && !tra[p][c]) tra[p][c] = now, p = link[p];
    if(p == -1) link[now] = 0;
    else
      {
		int q = tra[p][c];
		if(len[q] == len[p] + 1) link[now] = q;
		else
		  {
		    int clo = ++cnt; Mem(tra[cnt], 0);
		    memcpy(tra[clo], tra[q], sizeof(tra[q]));
		    len[clo] = len[p] + 1;
		    link[clo] = link[q], link[q] = link[now] = clo;
		    while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p];
		  }
      }
    las = now;
    if(~id) cur[id] = now;
  }
  int pos[maxn << 1], buc[maxn << 1];
  In void solve()
  {
    for(int i = 1; i <= cnt; ++i) ++buc[len[i]];
    for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
    for(int i = 1; i <= cnt; ++i) pos[buc[len[i]]--] = i;
    for(int i = cnt; i; --i)
      {
		int now = pos[i], fa = link[now];
		root[fa] = merge(root[fa], root[now], 0, n - 1);
      }
  }
  In ll calc(char* s)
  {
    ll ret = 0;
    int m = strlen(s);
    for(int i = 0, l = 0, p = 0, pos = 0; i < m; ++i)
      {
		int c = s[i] - 'a';
		while(~p && (!tra[p][c] || (pos = query(root[tra[p][c]], 0, n - 1, L, R)) == -1))
		  p = link[p], l = len[p];	//没有转移边或者[L, R]内没有endpos,就一直失配 
		if(p == -1) p = l = 0;
		else
		  {
		    ++l, p = tra[p][c];
		    int tp = min(l, pos - L + 1);
		    if(tp > ha[i]) ret += tp - ha[i];
		    if(l > pos - L + 1)	//匹配长度仍受L的限制 
		      {
				int tp3 = max(tp, ha[i]), q = link[p], l2 = len[q];
				while(~q)
				  {
				    int pos2 = query(root[q], 0, n - 1, L, R), tp2 = min(l2, pos2 - L + 1);
				    if(tp2 > tp3) ret += tp2 - tp3;
				    else if(l2 <= pos2 - L + 1) break;	//匹配长度不受L的限制了,退出 
				    q = link[q], l2 = len[q], tp3 = max(tp2, ha[i]);
				  }
		      }
		  }
      }
    return ret;
  }
}S, T;

int main()
{
  MYFILE();
  scanf("%s", s);
  n = strlen(s); S.init();
  for(int i = 0; i < n; ++i)
    {
      S.insert(s[i] - 'a', i);
      update(root[cur[i]], 0, n - 1, i);
    }
  S.solve();
  int m = read();
  for(int i = 1; i <= m; ++i)
    {
      scanf("%s", t); int len = strlen(t);
      T.init();
      for(int j = 0; j < len; ++j) T.insert(t[j] - 'a', -1), ha[j] = T.len[T.link[T.las]];
      L = read() - 1, R = read() - 1;
      ll tp = 0;
      for(int j = 0; j < len; ++j) tp += j + 1 - ha[j];
      write(tp - S.calc(t)), enter;
    }
  return 0;
}
posted @ 2019-05-29 22:43  mrclr  阅读(641)  评论(0编辑  收藏  举报