题解 LOJ2065 「SDOI2016」模式字符串
LOJ2065 「SDOI2016」模式字符串
题目大意
给定一棵树,每个节点上有一个字符。给定一个字符串 \(S\)。问有多少对有序点对 \((u, v)\),满足树上从 \(u\) 到 \(v\) 路径上的字符顺次连接形成的字符串,是由 \(S\) 重复若干次(必须是整数次)得到的?
数据范围:多组测试数据。\(1\leq \sum n, \sum m\leq 10^6\),其中 \(n\), \(m\) 分别表示树的节点数和字符串 \(S\) 的长度。
本题题解
点分治。
考虑经过当前分治中心\(u\)的点对数量。
这种数点对数的问题,有一个套路。我们可以依次考虑\(u\)的每个儿子,看用当前的儿子,能和之前已经考虑过的所有儿子,组成多少点对。这样所有合法的点对都会被恰好计算一次。
现在搜索\(u\)的一个儿子\(v\)的子树。对子树里的每个点,考虑它到\(u\)的有向路径形成的串。在搜索的过程中,我们每次要在当前串的“开头”处添加一个字符(即把整个串整体右移一位),没有什么好的数据结构可以维护,于是想到哈希。现在我们要判断,当前的串,是否是“若干个\(s\)”的一个前缀;如果是(称这样的节点是合法的),那我们要知道它最后匹配到的“零头”是多长,也即这个“前缀”的长度\(\bmod m\)的余数是多少。具体地,在搜索时,我们维护一个桶\(buc\)。\(buc[i]\)表示有多少个合法的节点\(x\),使得\(x\)到\(u\)的串的长度\(\bmod m=i\)。
这样就维护出了所有的前缀。现在我们想知道,\(v\)的子树内每个合法的前缀,能匹配\(v\)之前的子树内的多少合法的后缀。在搜索时,我们用和维护前缀类似的方法来维护后缀。对后缀,我们把\(s\)整体反转,然后也开一个桶,做和匹配前缀时一样的操作即可。
同样,对于\(v\)子树内的所有合法的后缀,我们也要知道它能匹配\(v\)之前的子树内的多少合法的前缀。(这是因为路径是有向的,因此要拿\(v\)内的前缀匹配一次前面的后缀,再拿\(v\)内的后缀匹配一次前面的前缀)。
现在完成了对\(v\)的子树的搜索,也把\(v\)子树的贡献计入了答案。我们得到了两个桶,分别是\(v\)内所有合法前缀的串长\(\bmod m\)的值为\(i\)的点的数量,和\(v\)内所有合法后缀的串长\(\bmod m\)的值为\(i\)的点的数量。现在,\(v\)这棵子树的身份就从“当前子树”,变成了“当前子树之前的子树”。于是拿这两个\(v\)的桶去分别更新两个“全局桶”即可。
注意,桶的大小是\(\min(m,\text{maxdep}_v)\),在更新全局桶和清空小桶时一定不能直接for
到\(m\),否则复杂度就不对了。
除了点分治,其他部分的复杂度是线性的。因此总时间复杂度\(O(n\log n)\)。
参考代码
//problem:loj2065
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
/* ------ by:duyi ------ */ // dysyn1314
const int MAXN=1e6;
const ull BASE=31;
char t[MAXN+5],s[MAXN+5];//t是树上的字符,s是模式串
ull h_pre[MAXN+5],h_suf[MAXN+5],pw[MAXN+5];
int n,m;
struct EDGE{int nxt,to;}edge[MAXN*2+5];
int head[MAXN+5],tot;
inline void add_edge(int u,int v){edge[++tot].nxt=head[u],edge[tot].to=v,head[u]=tot;}
bool vis[MAXN+5];
int SZ,rt,sz[MAXN+5],mxs[MAXN+5];
void _get_root(int u,int fa){
sz[u]=1;
mxs[u]=0;
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].to;
if(v==fa||vis[v])continue;
_get_root(v,u);
sz[u]+=sz[v];
mxs[u]=max(mxs[u],sz[v]);
}
mxs[u]=max(mxs[u],SZ-sz[u]);
if(!rt||mxs[u]<mxs[rt])rt=u;
}
void get_root(int fullsize,int temproot){
rt=0;SZ=fullsize;
_get_root(temproot,0);
}
ll ans;
int mxd,pre[MAXN+5],suf[MAXN+5],spre[MAXN+5],ssuf[MAXN+5];
char rootchar;
void dfs(int u,int fa,int dep,ull h){
mxd=max(mxd,dep);
h=h+(ull)t[u]*pw[dep-1];
if(h==h_pre[dep]){
pre[dep%m]++;
if(rootchar==s[dep%m+1])ans+=ssuf[m-dep%m-1];
}
if(h==h_suf[dep]){
suf[dep%m]++;
if(rootchar==s[m-dep%m])ans+=spre[m-dep%m-1];
}
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].to;
if(v==fa||vis[v])continue;
dfs(v,u,dep+1,h);
}
}
void solve(int u){
//cerr<<"solve "<<u<<endl;
vis[u]=1;
int mxD=0;
rootchar=t[u];
if(rootchar==s[1])spre[0]++;
if(rootchar==s[m])ssuf[0]++;
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].to;
if(vis[v])continue;
mxd=0;
dfs(v,u,1,0);
mxD=max(mxD,mxd);
for(int j=0;j<m&&j<=mxd;++j){
spre[j]+=pre[j];
ssuf[j]+=suf[j];
pre[j]=suf[j]=0;
}
}
for(int i=0;i<m&&i<=mxD;++i)spre[i]=ssuf[i]=0;
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].to;
if(vis[v])continue;
get_root(sz[v],v);
solve(rt);
}
}
void get_hash(char *s,ull *h){
for(int i=m+1;i<=n;++i)s[i]=s[i-m];
for(int i=1;i<=n;++i)h[i]=h[i-1]*BASE+s[i];
}
void clr(){
memset(head,0,sizeof(int)*(n+3));
memset(vis,0,sizeof(bool)*(n+3));
tot=0;
ans=0;
}
int main() {
pw[0]=1;for(int i=1;i<=MAXN;++i)pw[i]=pw[i-1]*BASE;
int Testcases;scanf("%d",&Testcases);while(Testcases--){
scanf("%d%d%s",&n,&m,t+1);for(int i=1;i<=n;++i)t[i]-='A';
for(int i=1,u,v;i<n;++i)scanf("%d%d",&u,&v),add_edge(u,v),add_edge(v,u);
scanf("%s",s+1);for(int i=1;i<=m;++i)s[i]-='A';
get_hash(s,h_pre);reverse(s+1,s+m+1);get_hash(s,h_suf);reverse(s+1,s+m+1);
get_root(n,1);solve(rt);
printf("%lld\n",ans);
clr();
}
return 0;
}