BZOJ 1921: [Ctsc2010]珠宝商 点分治套SAM
题目链接
首先可以发现两种算法
一. 暴力处理
对"特征串"建SAM
枚举路径的一个端点,\(dfs\)另一个端点,同时维护在SAM上的位置.
每到一个位置会有SAM上对应节点的right集合大小的贡献
复杂度\(O(size^2)\)
二. 处理经过一个点的所有路径
设这个点是\(u\),字符为\(a[u]\)
需要建出正反特征串的后缀树
考虑从点\(u\)出发的所有路径(\(a[u]\)为字符串的开头),统计出以特征串的每一位为起始的这些串的数量
同理将这些串\(reverse\)(\(a[u]\)为串的最后一位),统计出在特征串每一位结束的串的数量
对应位上两组串数量的乘积的和即为贡献,因为某正串在一位起始,某反串在这位结束,即可拼出一个路径
由于\(dfs\)时正串是每次在末尾加字符维护在特征串中起始位置,反串是每次开头加字符维护结束位置,将特征串和路径串翻转后即为同一个问题,我们只考虑\(push\ front\)维护结束位置
后缀自动机的转移只能支持末尾插入,于是需要利用后缀树
后缀树上一条边会对应原串中的一段区间
转移时需要注意从上往下走了不满一条边的情况,此时大概需要走到儿子处,注意判无转移时无解
每次在节点上打标记,最后全部下放到叶子处 统计每个位置的出现次数
复杂度\(O(size+m)\)
然后使用点分治,若分治的大小\(size>\sqrt{m}\)使用方法2, 否则暴力做方法1
这里需要注意同一子树的去重 在去重的时候应使用对应的方法保证复杂度
易知复杂度为\(O((n+m)\sqrt{m})\)
代码如下
#include<cstdio>
#include<algorithm>
#include<ctype.h>
#include<string.h>
#include<math.h>
using namespace std;
#define ll long long
#define rep(i,x,y) for(int i=(x);i<=(y);++i)
#define travel(i,x) for(int i=h[x];i;i=pre[i])
inline char read() {
static const int IN_LEN = 1000000;
static char buf[IN_LEN], *s, *t;
return (s == t ? t = (s = buf) + fread(buf, 1, IN_LEN, stdin), (s == t ? -1 : *s++) : *s++);
}
template<class T>
inline void read(T &x) {
static bool iosig;
static char c;
for (iosig = false, c = read(); !isdigit(c); c = read()) {
if (c == '-') iosig = true;
if (c == -1) return;
}
for (x = 0; isdigit(c); c = read()) x = ((x + (x << 2)) << 1) + (c ^ '0');
if (iosig) x = -x;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN], *ooh = obuf;
inline void print(char c) {
if (ooh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), ooh = obuf;
*ooh++ = c;
}
template<class T>
inline void print(T x) {
static int buf[30], cnt;
if (x == 0) print('0');
else {
if (x < 0) print('-'), x = -x;
for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;
while (cnt) print((char)buf[cnt--]);
}
}
inline void flush() { fwrite(obuf, 1, ooh - obuf, stdout); }
const int N = 50005;
int n, m, num, tot1[N], tot2[N], h[N], pre[N<<1], e[N<<1];
bool vis[N];
char a[N], s[N];
inline void add(int x, int y){ e[++num]=y, pre[num]=h[x], h[x]=num;}
struct sam{
int last, cnt, b[N], str[N], t[N<<1], lazy[N<<1], q[N<<1], siz[N<<1], len[N<<1], fa[N<<1], ch[N<<1][26], son[N<<1][26];
bool isl[N<<1];
inline sam(){ last=cnt=1;}
inline void ins(int c){
int p=last, np=++cnt;
last=np, str[len[np]=len[p]+1]=c, t[np]=len[np];
while(p && !ch[p][c]) ch[p][c]=np, p=fa[p];
if(!p) fa[np]=1;
else{
int q=ch[p][c];
if(len[q]==len[p]+1) fa[np]=q;
else{
int nq=++cnt;
len[nq]=len[p]+1, memcpy(ch[nq], ch[q], sizeof ch[0]);
t[nq]=t[q], fa[nq]=fa[q], fa[q]=fa[np]=nq;
while(ch[p][c]==q) ch[p][c]=nq, p=fa[p];
}
}
siz[np]=isl[np]=1;
}
inline void init(){
rep(i, 1, cnt) ++b[len[i]];
rep(i, 1, m) b[i]+=b[i-1];
rep(i, 1, cnt) q[b[len[i]]--]=i;
for(int i=cnt; i>1; --i) son[fa[q[i]]][str[t[q[i]]-len[fa[q[i]]]]]=q[i], siz[fa[q[i]]]+=siz[q[i]];
}
inline void trans(int &p, int c){ p=ch[p][c];}
void dfs5(int u, int fa, int p, int l){
if(!p) return;
if(l==len[p]) p=son[p][a[u]];
else if(str[t[p]-l]!=a[u]) p=0;
if(!p) return;
++lazy[p];
travel(i, u) if(e[i]!=fa && !vis[e[i]]) dfs5(e[i], u, p, l+1);
}
inline void work(int *tot){
rep(i, 2, cnt) lazy[q[i]]+=lazy[fa[q[i]]];
rep(i, 1, cnt) if(isl[i]) tot[len[i]]=lazy[i];
memset(lazy, 0, sizeof lazy);
}
}sam1, sam2;
ll ans;
int root, ctr, Siz, mn, top, lim, siz[N], stk[N];
void dfs1(int u, int fa=0){
siz[u]=1;
int mx=0;
travel(i, u) if(!vis[e[i]] && e[i]!=fa) dfs1(e[i], u), siz[u]+=siz[e[i]], mx=max(mx, siz[e[i]]);
mx=max(mx, Siz-siz[u]);
if(mx<mn) mn=mx, ctr=u;
}
inline int getctr(int u, int size){ return Siz=mn=size, dfs1(u), ctr;}
void dfs3(int u, int p, int W, int fa=0){
sam1.trans(p, a[u]);
if(p){
ans+=W*sam1.siz[p];
travel(i, u) if(!vis[e[i]] && e[i]!=fa) dfs3(e[i], p, W, u);
}
}
void dfs2(int u, int fa=0){
stk[++top]=a[u];
int p=1;
for(int i=top; i; --i) sam1.trans(p, stk[i]);
dfs3(root, p, -1);
travel(i, u) if(e[i]!=fa && !vis[e[i]]) dfs2(e[i], u);
--top;
}
void dfs4(int u, int fa=0){
dfs3(u, 1, 1);
travel(i, u) if(!vis[e[i]] && e[i]!=fa) dfs4(e[i], u);
}
void solve(int u, int fa=0, int v=0){
int size=Siz;
if(size<=lim){
if(fa){
stk[top=1]=a[fa];
root=v, dfs2(v);
}
dfs4(u);
}
else{
if(fa){
sam1.dfs5(v, fa, sam1.son[1][a[fa]], 1), sam2.dfs5(v, fa, sam2.son[1][a[fa]], 1);
sam1.work(tot1), sam2.work(tot2);
rep(i, 1, m) ans-=tot1[i]*tot2[m-i+1];
}
sam1.dfs5(u, 0, 1, 0), sam2.dfs5(u, 0, 1, 0);
sam1.work(tot1), sam2.work(tot2);
rep(i, 1, m) ans+=tot1[i]*tot2[m-i+1];
vis[u]=1;
travel(i, u) if(!vis[e[i]]) solve(getctr(e[i], siz[e[i]]<siz[u]?siz[e[i]]:size-siz[u]), u, e[i]);
}
}
int main() {
read(n), read(m);
lim=sqrt(m);
rep(i, 2, n){
static int x, y;
read(x), read(y);
add(x, y), add(y, x);
}
while(isspace(a[1]=read()));
rep(i, 2, n) a[i]=read();
rep(i, 1, n) a[i]-='a';
while(isspace(s[1]=read()));
rep(i, 2, m) s[i]=read();
rep(i, 1, m) sam1.ins(s[i]-='a');
for(int i=m; i; --i) sam2.ins(s[i]);
sam1.init(), sam2.init();
solve(getctr(1, n));
return printf("%lld", ans), 0;
}