题解 [CF204E] Little Elephant and Strings
算是比较好想的了
考虑每个串的每个前缀有哪些后缀是合法的
发现是这个前缀对应的节点到根节点路径上的一个前缀
那么可以倍增找这个前缀的长度
问题就变为对每个节点计算这个节点对应的子串被多少个原串包含
这其实就是一个树上数颜色问题
- 关于树上数颜色:
一个最经典的做法显然是 dsu on tree,然而并不好写
所以还有一个较好写的做法:将每种颜色对应的点按 dfs 序排序后差分
即在每个点产生 1 贡献,在相邻两点 lca 处产生 -1 贡献
因为这是 SAM,论文上其实也有证明从每个点向上爆跳,若当前点已被这种颜色计算过了就 break 的复杂度是 \(O(n\sqrt n)\) 的
那么写树上数颜色的复杂度是 \(O(n\log n)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 200010
#define pb push_back
#define ll long long
//#define int long long
int n, k;
char *s[N], t[N];
int head[N], ecnt;
vector<int> endpos[N];
int len[N], fail[N], tr[N][26], dep[N], fa[23][N], lg[N], id[N], val[N], now, tot, cnt;
struct edge{int to, next;}e[N];
inline void add(int s, int t) {e[++ecnt]={t, head[s]}; head[s]=ecnt;}
void init() {fail[now=tot=0]=-1;}
void insert(char c) {
c-='a';
if (tr[now][c]) {
int cur=tr[now][c];
if (len[cur]==len[now]+1) now=cur;
else {
int cln=++tot;
len[cln]=len[now]+1;
fail[cln]=fail[cur];
for (int i=0; i<26; ++i) tr[cln][i]=tr[cur][i];
for (; ~now&&tr[now][c]==cur; tr[now][c]=cln,now=fail[now]);
fail[cur]=cln;
now=cln;
}
}
else {
int cur=++tot;
len[cur]=len[now]+1;
int p, q;
for (p=now; ~p&&!tr[p][c]; tr[p][c]=cur,p=fail[p]);
if (p==-1) fail[cur]=0;
else if (len[q=tr[p][c]]==len[p]+1) fail[cur]=q;
else {
int cln=++tot;
len[cln]=len[p]+1;
fail[cln]=fail[q];
for (int i=0; i<26; ++i) tr[cln][i]=tr[q][i];
for (; ~p&&tr[p][c]==q; tr[p][c]=cln,p=fail[p]);
fail[cur]=fail[q]=cln;
}
now=cur;
}
}
void dfs1(int u) {
// cout<<"dfs1: "<<u<<endl;
id[u]=++cnt;
for (int i=1; i<23; ++i)
if (dep[u]>=1<<i) fa[i][u]=fa[i-1][fa[i-1][u]];
else break;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
fa[0][v]=u;
dep[v]=dep[u]+1;
dfs1(v);
}
}
int lca(int a, int b) {
if (dep[a]<dep[b]) swap(a, b);
while (dep[a]>dep[b]) a=fa[lg[dep[a]-dep[b]]-1][a];
if (a==b) return a;
for (int i=lg[dep[a]]-1; ~i; --i)
if (fa[i][a]!=fa[i][b])
a=fa[i][a], b=fa[i][b];
return fa[0][a];
}
void dfs2(int u) {
for (int i=head[u]; ~i; i=e[i].next)
dfs2(e[i].to), val[u]+=val[e[i].to];
}
signed main()
{
scanf("%d%d", &n, &k);
init();
memset(head, -1, sizeof(head));
for (int i=1,len; i<=n; ++i) {
scanf("%s", t+1);
len=strlen(t+1);
s[i]=new char[len+5];
for (int j=1; j<=len+1; ++j) s[i][j]=t[j];
now=0;
for (int j=1; j<=len; ++j) insert(s[i][j]), endpos[i].pb(now);
}
// cout<<"fail: "; for (int i=1; i<=tot; ++i) cout<<fail[i]<<' '; cout<<endl;
for (int i=1; i<=tot; ++i) add(fail[i], i);
for (int i=1; i<=tot+1; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
// cout<<"tot: "<<tot<<endl;
dep[0]=1; dfs1(0);
for (int i=1; i<=n; ++i) {
sort(endpos[i].begin(), endpos[i].end(), [](int a, int b){return id[a]<id[b];});
// cout<<"endpos: "; for (auto it:endpos[i]) cout<<it<<' '; cout<<endl;
for (auto it:endpos[i]) ++val[it];
for (int j=1; j<endpos[i].size(); ++j) --val[lca(endpos[i][j-1], endpos[i][j])];
}
dfs2(0);
// cout<<"dep: "; for (int i=1; i<=tot; ++i) cout<<dep[i]<<' '; cout<<endl;
for (int i=1; i<=n; ++i) {
ll ans=0;
int u=0, len=strlen(s[i]+1);
for (int j=1,t; j<=len; ++j) {
t=u=tr[u][s[i][j]-'a'];
// cout<<"u: "<<u<<endl;
for (int l=lg[dep[t]]-1; ~l; --l) {
// cout<<"fa: "<<fa[l][t]<<endl;
if (val[fa[l][t]]<k) t=fa[l][t];
}
// cout<<"t: "<<t<<endl;
if (val[t]<k) t=fa[0][t];
// cout<<"t: "<<t<<endl;
ans+=::len[t];
}
printf("%lld ", ans);
}
printf("\n");
// for (int i=1; i<=n; ++i) assert(val[fail[i]]>=val[i]);
// cout<<"val: "; for (int i=0; i<=tot; ++i) cout<<val[i]<<' '; cout<<endl;
return 0;
}