树哈希小记
树同构问题
给出两棵树,问是否同构。
同构:存在一种将结点重新编号的方案,使得两棵树完全相等。
树哈希的构造
设 \(h[u]\) 表示 \(u\) 子树的哈希值,我们容易想到:
或者
然后带个模数。
OI-wiki 上有一种实现简单的不容易被卡的哈希构造方法,建议使用这种方法:
其中 \(c\) 可以用 \(\text{xor shift}\)。
点击查看代码
mt19937 rnd(time(0));
const ull _=1ull*rnd()*rnd()*rnd()*rnd();
ull xh(ull x){
x^=_;
x^=x<<13;
x^=x>>7;
x^=x<<17;
x^=_;
return x;
}
void dfs(ll u,ll fa){
f[u]=1;
for(ll v:to[u])
if(v!=fa){
dfs(v,u);
h[u]+=xh(h[v]);
}
}
例题
题意:给出两棵树,判断是否重构。 \(1\le n\le 10^6\)
模板题,上面是正解。
这里的树是无根的。
考虑上面的树哈希,子树具有可加可减性,可以换根,然后选取 hash 值最小的那个。
当然,对于一棵树找出他的重心,直接拿重心作为根。
一棵树可能有两个重心,如果重心个数不一样,一定不同构;如果是两个重心,枚举一下匹配就好。
\(O(nm)\)。
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned ll
#define mkp make_pair
#define fi first
#define se second
#define pir pair<ll,ll>
#define pb push_back
using namespace std;
const ll maxn=1e6+10;
ll n,m,u,v;
ull f[maxn],g[maxn],val;
vector<ll>to[maxn];
unordered_map<ull,ll>mp;
mt19937 rnd(time(0));
const ull _=1ull*rnd()*rnd()*rnd()*rnd();
ull xh(ull x){
x+=_;
x^=x<<13;
x^=x>>7;
x^=x<<17;
x-=_;
return x;
}
void dfs1(ll u,ll fa){
f[u]=1;
for(ll v:to[u])
if(v!=fa){
dfs1(v,u);
f[u]+=xh(f[v]);
}
}
void dfs2(ll u,ll fa){
val=min(val,f[u]+g[u]);
for(ll v:to[u])
if(v!=fa){
g[v]=xh(g[u]+f[u]-xh(f[v]));
dfs2(v,u);
}
}
int main(){
scanf("%lld",&m);
for(ll i=1;i<=m;i++){
scanf("%lld",&n);
for(ll i=1,x;i<=n;i++){
scanf("%lld",&x);
if(x) to[x].pb(i), to[i].pb(x);
}
val=-1;
dfs1(1,0);
dfs2(1,0);
if(!mp.count(val)) mp[val]=i;
printf("%lld\n",mp[val]);
for(ll i=1;i<=n;i++) to[i].clear();
}
return 0;
}
题意:给出两棵有根树,其中 \(|T_2|\le |T_1|\le |T_2|+k\),判断是否存在删掉 \(T_1\) 中一些叶子且不能删根,使得 \(T_1\) 和 \(T_2\) 同构。
\(1\le n\le 5\times 10^5,\space 0\le k\le 5\)
注意到 \(k\) 很小,可以使用一些小规模的爆搜。
设计函数 \(chk(u_1,u_2)\) 表示两棵树的点 \(u_1,u_2\) 对应子树在删一些点后是否可能同构。
考虑 \(u_1,u_2\) 的儿子,会发现已经同构的儿子子树一定可以直接消去。
可以反证,如果两个同构的儿子子树不匹配,一个匹配大的,一个匹配小的,那么和大的和小的匹配是等价的。
因此,利用树哈希,不断消去同构的子树,\(u_1\) 和 \(u_2\) 都会剩下一些不匹配的儿子。
\(u_1\) 剩下的儿子个数一定不能超过 \(k\),否则不合法。
接下来考虑直接枚举 \(k\) 的全排列来暴力匹配,因为 \(k\) 很小所以可以接受。
很多地方可以剪枝,加上剪枝之后时间复杂度仍然比较玄学,\(k\) 很小而且还是子树加和,反正就是能过。
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned ll
#define mkp make_pair
#define fi first
#define se second
#define pir pair<ll,ll>
#define pb push_back
using namespace std;
const ll maxn=1e6+10;
ll t,n1,n2,r1,r2,siz1[maxn],siz2[maxn];
ull h1[maxn],h2[maxn];
vector<ll>to1[maxn],to2[maxn];
mt19937 rnd(time(0));
const ull _=1ull*rnd()*rnd()*rnd()*rnd()*rnd()*rnd()*rnd();
ull xh(ull x){
x^=_;
x^=x<<13;
x^=x>>7;
x^=x<<17;
x^=_;
return x;
}
void dfs1(ll u){
h1[u]=1, siz1[u]=1;
for(ll v:to1[u]){
dfs1(v);
h1[u]+=xh(h1[v]), siz1[u]+=siz1[v];
}
}
void dfs2(ll u){
h2[u]=1, siz2[u]=1;
for(ll v:to2[u]){
dfs2(v);
h2[u]+=xh(h2[v]), siz2[u]+=siz2[v];
}
}
bool cmp1(ll u,ll v) {return h1[u]<h1[v]; }
bool cmp2(ll u,ll v) {return h2[u]<h2[v]; }
ll w1[maxn],w2[maxn],l1,l2,per[maxn][5];
vector<ll>t1[maxn],t2[maxn];
bool chk(ll u1,ll u2){
if(siz1[u1]<siz2[u2]) return false;
if(siz1[u1]==siz2[u2]) return h1[u1]==h2[u2];
l1=l2=0;
t1[u1].clear(), t2[u2].clear();
for(ll v:to1[u1]) w1[++l1]=v;
for(ll v:to2[u2]) w2[++l2]=v;
sort(w1+1,w1+1+l1,cmp1);
sort(w2+1,w2+1+l2,cmp2);
ll p1=1, p2=1;
while(p1<=l1&&p2<=l2){
if(h1[w1[p1]]==h2[w2[p2]]) ++p1, ++p2;
else if(h1[w1[p1]]<h2[w2[p2]]) t1[u1].pb(w1[p1]), ++p1;
else t2[u2].pb(w2[p2]), ++p2;
}
while(p1<=l1) t1[u1].pb(w1[p1]), ++p1;
while(p2<=l2) t2[u2].pb(w2[p2]), ++p2;
if(t1[u1].size()<t2[u2].size()||t1[u1].size()>siz1[u1]-siz2[u2]) return false;
ll l=t1[u1].size();
for(ll i=0;i<l;i++) per[u1][i]=i+1;
do{
ll fl=1;
for(ll i=0;i<l;i++){
ll j=per[u1][i]-1;
if(j<t2[u2].size()){
if(siz1[t1[u1][i]]<siz2[t2[u2][j]]||!chk(t1[u1][i],t2[u2][j])){
fl=0; break;
}
}
}
if(fl) return true;
}while(next_permutation(per[u1],per[u1]+l));
return false;
}
int main(){
// freopen("iso4.in","r",stdin);
// freopen("iso.out","w",stdout);
scanf("%*lld%lld%*lld",&t);
while(t--){
scanf("%lld",&n1);
for(ll i=1,x;i<=n1;i++){
scanf("%lld",&x);
if(x==-1) r1=i;
else to1[x].pb(i);
}
scanf("%lld",&n2);
for(ll i=1,x;i<=n2;i++){
scanf("%lld",&x);
if(x==-1) r2=i;
else to2[x].pb(i);
}
dfs1(r1);
dfs2(r2);
if(chk(r1,r2)) puts("Yes");
else puts("No");
for(ll i=1;i<=n1;i++) to1[i].clear();
for(ll i=1;i<=n2;i++) to2[i].clear();
}
return 0;
}