CF1073G Yet Another LCP Problem

题意描述

洛谷 CodeForce

\(lcp(i,j)\) 表示 \(i\) 这个后缀和 \(j\) 这个后缀的最长公共前缀长度

给定一个字符串,每次询问的时候给出两个正整数集合 \(A\)\(B\),求

\(\sum_{i \in A,j \in B}lcp(i,j)\) 的值。

数据范围:\(n,q\leq 2\times 10^5, \displaystyle\sum_{i=1}^{q} k_i\leq 2\times 10^5,\displaystyle\sum_{i=1}^{q}l_i\leq 2\times 10^5\)

solution

后缀自动机加虚树。

对于两个后缀的 \(lcp\) 等价于原串的后缀树两个节点 \(lca\)\(max-len\)

然后我们的问题就转化为了求 \(\displaystyle\sum_{i\in A,j\in B} max-len(lca(i,j))\)

这个数据范围提示我们往虚树的那一方面去想。

我们可以对两个集合中每个后缀对应的节点建立一棵虚树,在虚树上进行dp。具体来说就是:

统计每个点作为 \(lca\) 出现的次数,然后算一下贡献即可。

转移式:ans += len[x]*num[x][1]*num[to][0], ans += len[x]*num[x][0]*num[x][1]

代码写起来比较好写,但调起来就不是一回事了。

update:一开始交了好几发都是 \(WA\), 调着调着发现原来是我虚树的板子出锅了,我那个板子会把一个点在虚树中添加多次,然后就算重了,关键是消耗战那道题这种写法还 \(tm\) 过了,唉害人不浅啊。

Code

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#include<cstring>
using namespace std;
#define LL long long
const int N = 5e5+10;
int n,m,tot,Siz,type,cnt1,cnt2,cnt,top,last,x;
int head[N],dep[N],Top[N],fa[N],siz[N],son[N],st[N],b[N],num[N][2],dfn[N],sta[N];
int link[N],len[N],tr[N][30];
LL ans;
char s[N];
vector<int> v[N];
struct node
{
    int to,net;
}e[N<<1];
inline int read()
{
    int s = 0, w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
}
void add(int x,int y)
{
    e[++tot].to = y;
    e[tot].net = head[x];
    head[x] = tot;
}
void Add(int x,int y)
{
    v[x].push_back(y);
}
bool comp(int a,int b)
{
	return dfn[a] < dfn[b];
}
void Extend(int ch)
{
    int now = ++Siz, p;
    len[now] = len[last]+1;
    for(p = last; p && !tr[p][ch]; p = link[p]) tr[p][ch] = now;
    if(!p) link[now] = 1;
    else
    {
        int x = tr[p][ch];
        if(len[x] == len[p] + 1) link[now] = x;
        else
        {
            int y = ++Siz;
            len[y] = len[p] + 1;
            memcpy(tr[y],tr[x],sizeof(tr[x]));
            link[y] = link[x];
            link[x] = link[now] = y;
            while(p && tr[p][ch] == x)
            {
                tr[p][ch] = y;
                p = link[p];
            }
        }
    }
    last = now;
}
void get_tree(int x)
{
    dep[x] = dep[fa[x]] + 1; siz[x] = 1;
    for(int i = head[x]; i; i = e[i].net)
    {
        int to = e[i].to;
        if(to == fa[x]) continue;
        fa[to] = x;
        get_tree(to);
        siz[x] += siz[to];
        if(siz[to] > siz[son[x]]) son[x] = to;
    }
}
void dfs(int x,int topp)
{
    Top[x] = topp; dfn[x] = ++type;
    if(son[x]) dfs(son[x],topp);
    for(int i = head[x]; i; i = e[i].net)
    {
        int to = e[i].to;
        if(to == fa[x] || to == son[x]) continue;
        dfs(to,to);
    }
}
int Lca(int x,int y)
{
    while(Top[x] != Top[y])
    {
        if(dep[Top[x]] < dep[Top[y]]) swap(x,y);
        x = fa[Top[x]];
    }
    return dep[x] <= dep[y] ? x : y;
}
void build()
{
	sta[++top] = 1;
	int num = unique(b+1,b+cnt+1)-b-1;
	for(int i = 1; i <= num; i++)
	{
		int x = b[i];
		int lca = Lca(x,sta[top]);
		while(top > 1 && dep[lca] <= dep[sta[top-1]]) Add(sta[top-1],sta[top]), top--;
		if(sta[top] != lca) Add(lca,sta[top]), sta[top] = lca;
		if(sta[top] != x) sta[++top] = x;
	}
	while(top > 1) Add(sta[top-1],sta[top]), top--;
}
void dp(int x,int fa)
{
	ans += 1LL * len[x] * num[x][0] * num[x][1]; 
	for(int i = 0; i < v[x].size(); i++)
    {
        int to = v[x][i];
        if(to == fa) continue;
        dp(to,x);
        ans += 1LL * len[x] * (1LL * num[x][0] * num[to][1] + 1LL * num[x][1] * num[to][0]);
        num[x][0] += num[to][0];
        num[x][1] += num[to][1];
    }
}
void QK(int x,int fa)
{
    num[x][0] = num[x][1] = 0; 
    for(int i = 0; i < v[x].size(); i++)
    {
        int to = v[x][i];
        if(to == fa) continue;
        QK(to,x);
    }
    v[x].clear();
}
int main()
{
    n = read(); m = read();
    scanf("%s",s+1); last = Siz = 1;
    for(int i = n; i >= 1; i--) Extend(s[i]-'a'), st[i] = last;
    for(int i = 2; i <= Siz; i++) add(link[i],i);
    get_tree(1); dfs(1,0);
    for(int i = 1; i <= m; i++)
    {
        cnt1 = read(); cnt2 = read(); cnt = top = 0;
        for(int j = 1; j <= cnt1; j++)
        {
            x = read();
            b[++cnt] = st[x];
            num[st[x]][0]++;
        }
        for(int j = 1; j <= cnt2; j++)
        {
            x = read();
            b[++cnt] = st[x];
            num[st[x]][1]++;
        }
        sort(b+1,b+cnt+1,comp);
        build();
        dp(1,1); QK(1,1);
        printf("%lld\n",ans);
        ans = 0;
    }
    return 0;
}

posted @ 2021-03-11 22:23  genshy  阅读(56)  评论(0编辑  收藏  举报