bzoj4754: [Jsoi2016]独特的树叶
正儿八经的树哈希的模板题。
哈希方法是求出每个儿子的哈希值,排序,末尾加上该点的size连成一个串,再对这个串哈希。
因为要求出A所有同构的树的hash值,直接hash的复杂度是n^n log
先随便指定一个根哈希一遍,再树dp换根即可,换根的时候把每个人父亲传过来的hash值也扔进串中,然后对每个人的串求前后缀,就可以方便地计算出它的hash值和它要传给每个儿子的hash值了。
然后对于B树要求去掉一个叶子的hash值,那么一个点传给它的叶子儿子的值就是一个合法的值,此外单独考虑一下根的父亲是叶子的情况即可。
//Achen
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<vector>
#include<cmath>
#include<map>
const int N=1e5+7,mod=998244353,bs=1011139;
typedef long long LL;
using namespace std;
int n,tot,ans=-1;
map<LL,int>mp;
LL power[N],fh[N];
template<typename T> void read(T &x) {
char ch=getchar(); x=0; T f=1;
while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
if(ch=='-') f=-1,ch=getchar();
for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0'; x*=f;
}
int ecnt,fir[N],nxt[N*2],to[N*2],sz[N],ha[N],f[N];
void add(int u,int v) {
nxt[++ecnt]=fir[u]; fir[u]=ecnt; to[ecnt]=v;
nxt[++ecnt]=fir[v]; fir[v]=ecnt; to[ecnt]=u;
}
bool cmp(const int &A,const int &B) { return ha[A]<ha[B]; }
vector<int>v[N];
void get_hash(int x) {
int up=v[x].size();
ha[x]=0;
sort(v[x].begin(),v[x].end(),cmp);
for(int i=0;i<up;i++)
ha[x]=((LL)ha[x]+(LL)ha[v[x][i]]*power[i+1]%mod)%mod;
ha[x]=((LL)ha[x]+sz[x])%mod;
}
void dfs(int x,int fa) {
sz[x]=1; f[x]=fa;
for(int i=fir[x];i;i=nxt[i]) if(to[i]!=fa) {
dfs(to[i],x);
v[x].push_back(to[i]);
sz[x]+=sz[to[i]];
}
get_hash(x);
}
int pr[N];
void make_hash(int x) {
sort(v[x].begin(),v[x].end(),cmp);
int up=v[x].size();
pr[0]=0;
for(int i=0;i<up;i++) pr[i+1]=((LL)pr[i]+power[i+1]*ha[v[x][i]])%mod;
LL has=0;
ha[x]=(pr[up]+tot)%mod;
for(int i=up-1;i>=0;i--) {
if(v[x][i]!=n+2) {
fh[v[x][i]]=((has*power[i]%mod+pr[i])%mod+(tot-sz[v[x][i]]))%mod;
if(ecnt==2*n&&!nxt[fir[v[x][i]]]&&mp[fh[v[x][i]]])
if(ans==-1||v[x][i]<ans) ans=v[x][i];
}
else if(ecnt==2*n&&!nxt[fir[f[x]]]) {
LL tpp=((has*power[i]%mod+pr[i])%mod+(tot-1))%mod;
if(mp[tpp])
if(ans==-1|f[x]<ans) ans=f[x];
}
has=(has+ha[v[x][i]])*bs%mod;
}
}
void dp(int x,int fa) {
if(fa) { ha[n+2]=fh[x]; v[x].push_back(n+2); }
make_hash(x);
for(int i=fir[x];i;i=nxt[i]) if(to[i]!=fa) {
dp(to[i],x);
}
}
int main() {
read(n); power[0]=1; tot=n;
for(int i=1;i<=n;i++) power[i]=power[i-1]*bs%mod;
for(int i=1;i<n;i++) {
int u,v;
read(u); read(v);
add(u,v);
}
dfs(1,0);
dp(1,0);
for(int i=1;i<=n;i++) mp[ha[i]]=1,v[i].clear();
memset(fir,0,sizeof(fir)); ecnt=0; tot++;
for(int i=1;i<=n;i++) {
int u,v;
read(u); read(v);
add(u,v);
}
dfs(1,0);
dp(1,0);
printf("%d\n",ans);
return 0;
}
/*
5
1 2
2 3
1 4
1 5
1 2
2 3
3 4
4 5
3 6
*/