P4248 [AHOI2013] 差异

传送门

题目描述

给定一个长度为 \(n\) 的字符串 \(S\),令 \(T_i\) 表示它从第 \(i\) 个字符开始的后缀。求

\[\displaystyle \sum_{1\leqslant i<j\leqslant n}\operatorname{len}(T_i)+\operatorname{len}(T_j)-2\times\operatorname{lcp}(T_i,T_j) \]

其中,\(\text{len}(a)\) 表示字符串 \(a\) 的长度,\(\text{lcp}(a,b)\) 表示字符串 \(a\) 和字符串 \(b\) 的最长公共前缀。

题解

可能比较简单做法吗?
考虑我们枚举所有子串作为公共前缀,如果这个子串长度为\(x\),出现次数为\(y\), 那么任取两个出现次数都可以构成两个以它为前缀的后缀, 统计\(y*(y-1)*x\)为答案就行。

显然的问题来了,题目要求的是最长公共前缀,如果当前统计的不是最长的怎么办呢?
如果\(x\)不是最长的,那么就至少存在一个\(x+1\)的前缀,由于我们统计了每一个前缀, 我们不妨考虑在\(x+1\)时将这个答案减去。

对于每一个长度为\(x\)的子串,\(x-1\)必然也会被统计同样多次,我们把它减去,那么长度为1的前缀会在2处被剪掉,2会被3减, 3会被4减, 除非已经是最长公共前缀了,这样做就没有问题。

对于每一个\(x\),统计答案即为\(y*(y-1)*(x-(x-1)) = y*(y-1)\),也就是遍历sam,对于每一个endpos,将出现次数乘上出现次数减一再乘以当前endpos覆盖区间长度就好了。

题解说按边算贡献,本质差不多,但不失为一种不错角度

实现

没啥好说的,乘法记得都在前面乘个1ll,我直接扣sam板子很快就可以写出(如果不是因为我在学习神秘vim的话

#include <iostream>
#include <cstdio>
#include <vector>
#include <string>
#define ll long long
using namespace std;

int read(){
	int num=0, flag=1; char c=getchar();
	while(!isdigit(c) && c!='-') c=getchar();
	if(c=='-') flag=-1, c=getchar();
	while(isdigit(c)) num=num*10+c-'0', c=getchar();
	return num*flag; 
}

int readc(){
	char c=getchar();
	while(c<'a' || c>'z') c=getchar();
	return c-'a'; 
}

const int N = 500500;
int n, m;
char s[N];

ll ans = 0;

namespace sam{
	struct{
		int len, link, siz=0;
		int ch[26];
	}st[N<<1];
	vector<int> ptr[N<<1];
	int las, sz;
	
	void init(){
		las=1, sz=1;
		st[1].len=0, st[1].link=0; 
	}
	
	void extend(int c){
		int cur=++sz, p=las;
		st[cur].len=st[las].len+1, st[cur].siz=1;
		while(p && !st[p].ch[c]) 
			st[p].ch[c]=cur, p=st[p].link;
		
		if(p){
			int nex = st[p].ch[c];
			if(st[p].len+1 == st[nex].len){
				st[cur].link = nex;
			}else{
				int clone = ++sz;
				st[clone].len=st[p].len+1, st[clone].link = st[nex].link;
				for(int i=0; i<26; i++) st[clone].ch[i] = st[nex].ch[i];
				st[cur].link=clone, st[nex].link=clone;
				
				while(st[p].ch[c]==nex) st[p].ch[c]=clone, p=st[p].link;
			}
		}else{
			st[cur].link=1;
		}
		
		las = cur;
	}
	
	void dfsPtr(int x){
		for(int i=0; i<ptr[x].size(); i++) {
			dfsPtr(ptr[x][i]);
			st[x].siz += st[ptr[x][i]].siz;
		}
	}
	
	void buildPtr(){
		for(int i=1; i<=sz; i++) {
			ptr[st[i].link].push_back(i);
		}
		dfsPtr(1);
	}
	
	void clac1(){
		for(int i=1; i<=sz; i++){
			if(st[i].siz!=0)
            	ans += 1ll*st[i].siz * (st[i].siz - 1) * (st[i].len - st[st[i].link].len);
        } 
	}
}

string str;

int main(){
	ios::sync_with_stdio(false);
	cin.tie(nullptr); cout.tie(nullptr);
	
	cin >> str;
	for(int i=0; i<str.size(); i++){
		s[i+1] = str[i]-'a';
	} n = str.size();
	sam::init();
	for(int i=1; i<=n; i++) sam::extend(s[i]); 
	sam::buildPtr();
	sam::clac1();
    ans *= -1;
    // ll sum = 0;
    // for(int i=1; i<=n; i++){
    //     ans += sum + 1ll*(i-1)*(n-i+1);
    //     sum += n-i+1;
    // }
	ans += 1ll*(n-1)*n*(n+1)/2;
	cout << ans << endl;
	
	return 0;
}
posted @ 2024-08-04 21:33  ltdJcoder  阅读(7)  评论(0编辑  收藏  举报