[BZOJ4754][JSOI2016]独特的树叶
Description
JYY有两棵树\(A\)和\(B\):树\(A\)有\(N\)个点,编号为\(1\)到\(N\);树\(B\)有\(N+1\)个节点,编号为\(1\)到\(N+1\)。
JYY 知道树\(B\)恰好是由树\(A\)加上一个叶节点,然后将节点的编号打乱后得到的。他想知道,这个多余的叶子到底是树\(B\)中的哪一个叶节点呢?
sol
直接树\(hash\)就好了。
但是直接\(hash\)不是\(O(n^2)\)?
所以构造一个适合倒推的\(hash\),然后做一遍换根\(dp\)求以每个点为根的\(hash\)值。
这里构造的\(hash\)函数是:
\[Hash(u)=(\sum_\otimes Hash(v)+base1)\otimes(sz_u*base2+base3)
\]
然后做换根\(dp\)就好了。
我第一遍\(base\)没设好\(WA\)成\(50\)分。
code
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<map>
using namespace std;
int gi(){
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
#define ull unsigned long long
const int N = 2e5+5;
const ull base1 = 1406;
const ull base2 = 20020415;
const ull base3 = 20011118;
int n,to[N],nxt[N],head[N],cnt,sz[N],d[N],ans=1e9;
ull Hash[N];map<ull,bool>M;
void link(int u,int v){
to[++cnt]=v;nxt[cnt]=head[u];head[u]=cnt;
}
void dfs(int u,int f){
Hash[u]=0;sz[u]=1;
for (int e=head[u];e;e=nxt[e])
if (to[e]!=f){
dfs(to[e],u);
Hash[u]^=Hash[to[e]]+base1;
sz[u]+=sz[to[e]];
}
Hash[u]+=base2*sz[u]+base3;
}
void cal1(int u,int f){
M[Hash[u]]=1;
for (int e=head[u];e;e=nxt[e])
if (to[e]!=f){
ull tmp=((Hash[u]-base2*n-base3)^(Hash[to[e]]+base1))+base2*(n-sz[to[e]])+base3;
Hash[to[e]]=((Hash[to[e]]-base2*sz[to[e]]-base3)^(tmp+base1))+base2*n+base3;
cal1(to[e],u);
}
}
void cal2(int u,int f){
for (int e=head[u];e;e=nxt[e])
if (to[e]!=f)
if (d[to[e]]>1){
ull tmp=((Hash[u]-base2*n-base3)^(Hash[to[e]]+base1))+base2*(n-sz[to[e]])+base3;
Hash[to[e]]=((Hash[to[e]]-base2*sz[to[e]]-base3)^(tmp+base1))+base2*n+base3;
cal2(to[e],u);
}else{
ull tmp=((Hash[u]-base2*n-base3)^(Hash[to[e]]+base1))+base2*(n-1)+base3;
if (M.count(tmp)) ans=min(ans,to[e]);
}
}
int main(){
n=gi();
for (int i=1;i<n;++i){
int u=gi(),v=gi();
link(u,v),link(v,u);
}
dfs(1,0);cal1(1,0);
memset(head,0,sizeof(head));cnt=0;++n;
for (int i=1;i<n;++i){
int u=gi(),v=gi();
link(u,v),link(v,u);
++d[u],++d[v];
}
for (int i=1;i<=n;++i)
if (d[i]>1){
dfs(i,0),cal2(i,0);
break;
}
printf("%d\n",ans);return 0;
}