LOJ6681. yww 与树上的回文串
LOJ6681. yww 与树上的回文串
题意:给定一棵边上带 01 权值的树,求有多少对 \((x,y)\) 满足 \(x<y\) 且 \(x\) 到 \(y\) 路径上的边权拼起来是回文串。
\(n\leq 5\times 10^4\)。
tag:点分治 AC自动机 根号分治。
点分治,分治中心为 \(G\) 时有三种贡献可能:
- \(G\to x\) 回文。
- \(x\to G\to y\) 回文,且 \(x,y\) 属于不同子树,深度相等。
- \(x\to G\to y\) 回文,且 \(x,y\) 属于不同子树,深度不相等,设 \(dep_x>dep_y\)。
第一种情况用 hash 判断每个点对应的串是否是回文的,第二种情况用 umap 统计,第三种先建立整个分治范围内以 \(G\) 为根的 AC 自动机,fail 树上的祖先都对应自己的真后缀。
考虑 \(x\to G\to y\) 实际可以划分成三段:\(x\to o\to G\to y\),满足 \(o\to G\) 段回文,\(o\to x\) 和 \(G\to y\) 段相等,也就是说 \(G\to y\) 是 \(G\to x\) 的后缀,点 \(y\) 在 fail 树上会是点 \(x\) 的祖先。
引理:一个串的前缀回文串可以划分为 \(O(\log n)\) 个值域不交的等差数列。
证明见 OI-wiki。
先把等差数列预处理放到每个点上。在 fail 树上 dfs,维护出若干等差数列。逐棵把子树统计贡献再加入。
设当前点是 \(x\) 计算贡献,则合法的 \(y\) 一定在 dfs 栈中。
假设点 \(x\) 长度为 \(l\) 的前缀是回文的,那么 fail 树上的祖先中如果出现了 \(dep_y=dep_x-l\) 的点 \(y\),则 \((x,y)\) 会统计入答案。
对于维护出的等差数列首项,末项,公差分别为 \(s,e,d\),表示当前长度在 \([s,e]\) 中,且 \(\bmod d\) 意义下同余 \(s\) 的前缀均回文,则应该统计 fail 树的祖先中原树深度为以 \(a_1=dep_x-e\) 为首项,\(a_2=dep_x-s\) 为末项,公差为 \(d\) 中的元素的点 \(y\) 的标记之和。
统计时注意到不能和之前一样逐棵统计贡献并加入,因为 Trie 树上的一棵以根节点为根的连通子树不一定在 fail 树上也是,所以一起统计之后对于每棵子树再算一次答案容斥掉。
设置一个阈值 \(B\),维护一个 \(B\times B\) 大小的数组 \(c_{i,j}\) 表示当前 dfs 栈中的点,原树深度 \(\bmod i\) 等于 \(j\) 的标记之和,再维护一个大小为 \(n\) 的数组 \(t_i\) 表示原树深度为 \(i\) 的标记和,都容易加入和删除点的贡献。
当 \(d<B\) 时可以把需要计算的贡献先挂到树上,再扫一遍贡献到答案上;当 \(d\ge B\) 时直接枚举合法深度在 \(t\) 中统计答案。
理论上在 \(B=O(\sqrt n)\) 时取得最优复杂度 \(O(n\log^2 n+n\sqrt n)\)。
为了稍微好写一点,可以把差分再离线计算贡献的过程换成树状数组,且 \(B\) 取到 \(2\) 左右的时候跑得比较快,原因是很难构造数据来卡。
#include<bits/stdc++.h>
#define For(i,a,b) for(int i=(a),i##END=(b);i<=i##END;i++)
#define Rof(i,b,a) for(int i=(b),i##END=(a);i>=i##END;i--)
#define go(u) for(int i=head[u];i;i=nxt[i])
#define pi pair<int,int>
#define fi first
#define se second
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
const int N=5e4+10,p1=147744151,p2=666528221,base=131;
int pw1[N],pw2[N],h1[N],h2[N],H1[N],H2[N];//hash
int n,ans;vector<pi> g[N];
void add(int u,int v,int w){g[u].push_back(pi(v,w));}
int used[N];
int dep[N],ok[N];//深度,是否回文
int son[N][2],idx,vis[N];//Trie
int head[N],to[N],nxt[N],cnt;
void add(int u,int v){
to[++cnt]=v,nxt[cnt]=head[u],head[u]=cnt;
}
int q[N];int L,R; int fail[N];//ACAM
void build(){
For(i,0,idx)fail[i]=head[i]=0;L=1,R=0;
For(i,0,1)if(son[0][i])q[++R]=son[0][i];
while(L<=R){
int u=q[L++];
For(i,0,1)if(son[u][i])fail[son[u][i]]=son[fail[u]][i],q[++R]=(son[u][i]);
else son[u][i]=son[fail[u]][i];
}cnt=0;For(i,1,idx)add(fail[i],i);
}
struct node{int s,e,d;};
vector<node> pal[N];
int Dep[N];
#define rb(x) x.back()
int ban[N];
void prework(int u,int U,int f,int op=0){
assert(dep[u]==Dep[U]);
if(!op){
pal[u]=pal[f];
ok[u]=0;if(f){
if(h1[u]==H1[u]&&h2[u]==H2[u]){
ok[u]=1;
ans++;//#1
if(!pal[u].size())pal[u].push_back((node){dep[u],dep[u],1});
else if(rb(pal[u]).e+rb(pal[u]).d==dep[u])rb(pal[u]).e+=rb(pal[u]).d;
else{
auto it=rb(pal[u]);
if(it.s==it.e)rb(pal[u])=(node){it.s,dep[u],dep[u]-it.s};
else pal[u].push_back((node){dep[u],dep[u],1});
}
}
}
}
for(auto x:g[u]){
int v=x.fi,w=x.se;if(v==f||used[v])continue;
if(!op){
h1[v]=(1ll*h1[u]*base+w)%p1;
h2[v]=(1ll*h2[u]*base+w)%p2;
H1[v]=(1ll*w*pw1[dep[u]]+H1[u])%p1;
H2[v]=(1ll*w*pw2[dep[u]]+H2[u])%p2;
}
dep[v]=dep[u]+1;
int to=son[U][w-1];
if(!to)to=son[U][w-1]=++idx,
vis[idx]=0,son[idx][0]=son[idx][1]=0,Dep[idx]=Dep[U]+1;
prework(v,to,u,op);
}
}
void calc(int u,int U,int f){
ans+=vis[U];//#2
assert(dep[u]==Dep[U]);
for(auto x:g[u]){
int v=x.fi;if(v==f||used[v])continue;
calc(v,son[U][x.se-1],u);
}
}
void addin(int u,int U,int f){
if(U)vis[U]++;
for(auto x:g[u]){
int v=x.fi;if(v==f||used[v])continue;
addin(v,son[U][x.se-1],u);
}
}
vector<node> PAL[N];
void find(int u,int U,int f){
for(auto p:pal[u])PAL[U].push_back(p);
for(auto x:g[u]){
int v=x.fi;if(v==f||used[v])continue;
find(v,son[U][x.se-1],u);
}
}
const int B=2;
int t[N];
#define lowbit(x) (x&-x)
int cc[N];
inline void add(int u,int v,int* c){for(int i=u;i<=idx;i+=lowbit(i))c[i]+=v;}
inline int ask(int u,int *c,int s=0){for(int i=u;i;i-=lowbit(i))s+=c[i];return s;}
void dfs(int u,int f,int op=1){
for(auto p:PAL[u]){
int l=Dep[u]-p.e,r=Dep[u]-p.s,d=p.d;
if(d==1)ans+=op*(ask(r,cc)-(l?ask(l-1,cc):0));
else for(int i=l;i<=r;i+=d)ans+=op*t[i];
}
t[Dep[u]]+=vis[u];
if(vis[u])add(Dep[u],vis[u],cc);
go(u)dfs(to[i],u,op);
t[Dep[u]]-=vis[u];
if(vis[u])add(Dep[u],-vis[u],cc);
}
int all_num,fa[N],sz[N],mx[N],rt;
void getr(int u,int f){
sz[u]=1,mx[u]=0;for(auto x:g[u]){
int v=x.fi;
if(!used[v]&&v!=f)getr(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
}mx[u]=max(mx[u],all_num-sz[u]);if(mx[u]<mx[rt])rt=u;
}
void solve(int u){
used[u]=1;
h1[u]=h2[u]=H1[u]=H2[u]=dep[u]=Dep[0]=0,son[0][0]=son[0][1]=0,idx=0;
prework(u,0,0),build();
For(i,0,idx)vis[i]=0,vector<node>().swap(PAL[i]);
for(auto x:g[u]){
int v=x.fi,V=son[0][x.se-1];if(used[v])continue;
calc(v,V,u);
addin(v,V,u);
}
find(u,0,0),dfs(0,0);
for(auto x:g[u]){
int v=x.fi,w=x.se;if(used[v])continue;
h1[u]=h2[u]=H1[u]=H2[u]=dep[u]=Dep[0]=0,son[0][0]=son[0][1]=0,idx=0;
int to=son[0][w-1];
if(!to)to=son[0][w-1]=++idx,
vis[idx]=0,son[idx][0]=son[idx][1]=0,Dep[idx]=Dep[0]+1;
prework(v,to,u,1),build();
For(i,0,idx)vis[i]=0,vector<node>().swap(PAL[i]);
addin(v,to,u),find(v,to,u),dfs(0,0,-1);
}
for(auto x:g[u]){int v=x.fi;if(!used[v])getr(v,u),all_num=sz[v],rt=0,getr(v,u),solve(rt);}
}
signed main(){
For(i,2,n=read()){int u=read(),v=read(),w=read()+1;add(u,v,w),add(v,u,w);}
pw1[0]=pw2[0]=1,mx[0]=1e9;
For(i,1,n)pw1[i]=1ll*pw1[i-1]*base%p1,pw2[i]=1ll*pw2[i-1]*base%p2;
all_num=n,rt=0,getr(1,0),solve(rt);cout<<ans<<endl;
return 0;
}