[BZOJ1921][CTSC2010]珠宝商(点分治+后缀自动机)
[BZOJ1921][CTSC2010]珠宝商(点分治+后缀自动机)
题面
给出一个\(n\)个点的树,每个点上都有一个字符。再给出一个长度为\(m\)的特征串\(str\)。求树上所有简单路径经过节点的字符按顺序连接起来后的串在特征串中的出现次数之和。
分析
显然发现是点分治。考虑如何计算经过一个点\(x\)的所有路径的答案。
每次计算经过重心的字符串,可以拆成两部分:某个点到重心的字符串 + 重心到某个点的字符串。由于路径有向上和向下的方向之分,我们对特征串的正串和反串分别建SAM,记为\(T_1,T_2\)
容易想到一个简单的暴力:
先找到当前分治子树中的所有节点。对于子树中的每个节点,暴力从它开始dfs,同时在\(T_1\)上匹配,累加当前SAM节点代表的串在特征串中的出现次数.
这样的复杂度是\(O(sz[x]^2)\),其中\(sz[x]\)为x的子树大小。
直接暴力在节点很多的时候不太优秀,现在我们考虑另一种暴力:
我们从\(x\)从上往下dfs.找到某个点到\(x\)的路径上的字符串。注意是往前加字符,所以不能直接利用SAM的转移找到下一个节点。那么我们可以另外维护一个转移函数\(pre(x,c)\)表示\(x\)节点加上字符\(c\)
显然有
其中\(\max(right(x))\)表示出现位置的最大值(实际上任选一个出现位置都可以)。这是因为从x的结尾往前,去掉长度为\(len(link(x))\)的后缀,前面那个的字符就是加在\(link(x)\)对应的串前面的字符,而它又属于\(x\)状态对应的字符串。
有了\(pre\)之后我们就可以一边在原树上dfs一边向下匹配。设当前在SAM上的节点为\(d\),如果匹配长度\(<len(d)\),直接在原字符串S上看加入这个字符是否仍然合法. 否则就看\(pre(x,\text{当前字符})\)是否为空。
对于匹配的每个节点,我们在自动机上打一个标记。把所有标记打完后遍历整个自动机,累加parent树祖先的标记。这样就得到了每个节点的出现次数,即这个\(str\)的子串在多少从\(x\)向下的链中出现过。然后枚举\(str\)的每个前缀,计算出\(T_1\)中对应节点的出现次数,和\(T_2\)中对应节点的出现次数,再相乘得到答案。因为\(str\)的子串可能在路径向上的一段出现,也可能在路径向下的一段出现,而把这任意两段路径连起来都是合法的,所以要相乘。
因为长度为\(m\)的字符串的SAM大小和转移数都是\(O(m)\)的,容易发现这种暴力的复杂度为\(O(m)\)
我们考虑如何平衡这两种暴力的复杂度。如果我们在子树大小\(>B\)的时候执行第二种暴力,这样的子树不会超过\(\frac{n}{B}\)个,复杂度为\(O(\frac{n}{B} \times m)=O(\frac{nm}{B})\).如果我们分治到一个\(<B\)的子树就不再继续划分,那么这样的子树也不会超过\(\frac{n}{B}\),复杂度为\(O(\frac{n}{B} \times B^2)=O(nB)\)
总复杂度为\(O(n(\frac{m}{B}+B))\).当\(B=\sqrt{m}\)时最优,为\(O(n\sqrt{m})\)
注意一个细节,清除重复子树的时候也需要按照大小是否小于\(B\)来判断,否则会被卡成\(O(nm)\)
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#define maxn 200000
#define maxc 26
using namespace std;
typedef long long ll;
int n,m,B;
char pc[maxn+5];//每个节点的字母
char sp[maxn+5];//特征串
struct edge {
int from;
int to;
int next;
} E[maxn*2+5];
int esz=1;
int head[maxn+5];
void add_edge(int u,int v) {
esz++;
E[esz].from=u;
E[esz].to=v;
E[esz].next=head[u];
head[u]=esz;
}
bool vis[maxn+5];
struct SAM {
#define link(x) (t[x].link)
#define len(x) (t[x].len)
struct node {
int ch[maxc];
int pre[maxc];//代表的串前面加上字符c之后,转移到的自动机节点
int link;
int len;
int sz;
int tag;
int maxpos;
} t[maxn+5];
char str[maxn+5];//当前自动机对应的串
// vector<int>fail[maxn+5];
int id[maxn+5];
int seq[maxn+5];
inline void rsort() {
static int buck[maxn+5];
for(int i=1; i<=m; i++) buck[i]=0;
for(int i=1; i<=ptr; i++) buck[len(i)]++;
for(int i=1; i<=m; i++) buck[i]+=buck[i-1];
for(int i=1; i<=ptr; i++) seq[buck[len(i)]--]=i;
}
const int root=1;
int last=root;
int ptr=1;
int extend(int c,int pos) {
int p=last,cur=++ptr;
len(cur)=len(p)+1;
t[cur].sz=1;
t[cur].maxpos=pos;
while(p&&t[p].ch[c]==0) {
t[p].ch[c]=cur;
p=link(p);
}
if(p==0) link(cur)=root;
else {
int q=t[p].ch[c];
if(len(p)+1==len(q)) link(cur)=q;
else {
int clo=++ptr;
len(clo)=len(p)+1;
for(int i=0; i<maxc; i++) t[clo].ch[i]=t[q].ch[i];
link(clo)=link(q);
link(q)=link(cur)=clo;
while(p&&t[p].ch[c]==q) {
t[p].ch[c]=clo;
p=link(p);
}
}
}
last=cur;
return cur;
}
void ini_tree() {
for(int i=ptr; i>=1; i--) {
int y=seq[i],x=link(y);
t[x].pre[str[t[y].maxpos-len(x)]-'a']=y;//预处理
t[x].sz+=t[y].sz;
if(!t[x].maxpos) t[x].maxpos=t[y].maxpos;
}
}
void sum_tag() {
for(register int i=1; i<=ptr; i++) {
register int y=seq[i],x=link(y);
t[y].tag+=t[x].tag;
}
}
void insert(char *s) {
int len=strlen(s+1);
for(int i=1; i<=len; i++) str[i]=s[i];
for(int i=1; i<=len; i++) id[i]=extend(str[i]-'a',i);
rsort();
ini_tree();
}
void add_tag(int x,int fa,int nd,int matlen) {
if(matlen==len(nd)) nd=t[nd].pre[pc[x]-'a'];//如果长度为当前的len,就可以加上字符pc[x]转移
else if(str[t[nd].maxpos-matlen]!=pc[x]) nd=0;
if(nd==0) return;
matlen++;
t[nd].tag++;
for(int i=head[x]; i; i=E[i].next) {
int y=E[i].to;
if(y!=fa&&!vis[y]) add_tag(y,x,nd,matlen);
}
}
inline void clear_tag() {
for(int i=1; i<=ptr; i++) t[i].tag=0;
}
void debug() {
printf("tag:");
for(int i=1; i<=ptr; i++) printf("%d ",t[i].tag);
printf("\n");
}
} T1,T2;
ll ans=0;
int sz[maxn+5];
int f[maxn+5];
void dfs_root(int x,int fa,int tot_sz,int &root) { //找重心
sz[x]=1;
f[x]=0;
for(int i=head[x]; i; i=E[i].next) {
int y=E[i].to;
if(y!=fa&&!vis[y]) {
dfs_root(y,x,tot_sz,root);
sz[x]+=sz[y];
f[x]=max(f[x],sz[y]);
}
}
f[x]=max(f[x],tot_sz-sz[x]);
if(f[x]<f[root]) root=x;
}
int get_root(int x,int fa,int tot_sz) {
f[0]=tot_sz+1;
int root=0;
dfs_root(x,fa,tot_sz,root);
return root;
}
void bf_ans(int x,int fa,int nd,int type) { //暴力遍历经过x的所有路径求解答案
nd=T1.t[nd].ch[pc[x]-'a'];
if(nd==0) return;
ans+=T1.t[nd].sz*type;
for(int i=head[x]; i; i=E[i].next) {
int y=E[i].to;
if(y!=fa&&!vis[y]) bf_ans(y,x,nd,type);
}
}
void bf_tree(int x,int fa) { //对当前子树中的每个节点,遍历它的子树,O(size^2)
bf_ans(x,0,T1.root,1); //注意fa要设成0,这样才可以处理"转弯"的路径
for(int i=head[x]; i; i=E[i].next) {
int y=E[i].to;
if(y!=fa&&!vis[y]) bf_tree(y,x);
}
}
int top;
int stk[maxn+5];
void bf_del(int x,int fa,int nd,int matlen) {//类似add_tag,找到那些被标记的节点,准备下一步删掉
if(matlen==T1.t[nd].len) nd=T1.t[nd].pre[pc[x]-'a'];
else if(T1.str[T1.t[nd].maxpos-matlen]!=pc[x]) nd=0;
if(nd==0) return;
stk[++top]=nd;
matlen++;
for(int i=head[x]; i; i=E[i].next) {
int y=E[i].to;
if(y!=fa&&!vis[y]) bf_del(y,x,nd,matlen);
}
}
void calc(int x,int fa,int type) { //打标记求解
T1.clear_tag();
T2.clear_tag();
if(fa) {
T1.add_tag(x,fa,T1.t[T1.root].ch[pc[fa]-'a'],1);//要加上父亲节点的匹配
T2.add_tag(x,fa,T2.t[T2.root].ch[pc[fa]-'a'],1);
} else {
T1.add_tag(x,fa,T1.root,0);
T2.add_tag(x,fa,T2.root,0);
}
T1.sum_tag();
T2.sum_tag();
for(register int i=1; i<=m; i++) {
register int x1=T1.id[i];
register int x2=T2.id[m-i+1];//反串,对应位置为m-i+1
ans+=1ll*type*T1.t[x1].tag*T2.t[x2].tag;
}
}
void solve(int x) {
// printf("root=%d\n",x);
if(sz[x]<=B) {
bf_tree(x,0);
return;
}
calc(x,0,1);
vis[x]=1;
for(int i=head[x]; i; i=E[i].next) {
int y=E[i].to;
if(!vis[y]) {
if(sz[y]<=B){
top=0;
bf_del(y,x,T1.t[1].pre[pc[x]-'a'],1);
for(int j=1;j<=top;j++) bf_ans(y,0,stk[j],-1);
}else{
calc(y,x,-1);
}
}
}
for(int i=head[x]; i; i=E[i].next) {
int y=E[i].to;
if(!vis[y]) solve(get_root(y,x,sz[y]));
}
}
int main() {
//#ifdef DEBUG
// freopen("1.in","r",stdin);
//#endif
int u,v;
scanf("%d %d",&n,&m);
B=8*sqrt(n);
for(int i=1; i<n; i++) {
scanf("%d %d",&u,&v);
add_edge(u,v);
add_edge(v,u);
}
scanf("%s",pc+1);
scanf("%s",sp+1);
T1.insert(sp);
reverse(sp+1,sp+1+m);
T2.insert(sp);
solve(get_root(1,0,n));
printf("%lld\n",ans);
}