独特的树叶
独特的树叶
JYY有两棵树A和B:树A有N个点,编号为1到N;树B有N+1个点,编号为1到N+1。JYY知道树B恰好是由树A加上一个叶节点,然后将节点的编号打乱后得到的。他想知道,这个多余的叶子到底是树B中的哪一个叶节点呢?
n 1e5
sol
可以先把A树每一个点为根的树哈希值求出来,然后再把B树的每个叶子依次删去,判断即可。
两者都用换根实现。
有个知识:树哈希
注意到子树顺序是无关的,那么可以排序完再哈起来。
当然有个很妙的做法:异或。(换根也很方便!)
注意要加sz,否则容易被两棵哈希值一样的子树卡掉。
注意 ll tmp+=A*B 可能会溢出!
//而 tmp=tmp+A*B 似乎不会
1 #include<cstdio> 2 #include<iostream> 3 #include<cstdlib> 4 #include<cstring> 5 #include<algorithm> 6 #include<cmath> 7 #include<map> 8 #define ll unsigned long long 9 #define maxn 100005 10 #define p 1000000007 11 #define q 19990213 12 #define o 19260817 13 using namespace std; 14 int n,head[maxn],tot; 15 int sz[maxn],in[maxn],ans=1e9; 16 ll h[maxn],ha[maxn]; 17 map<ll,bool>f; 18 struct node{ 19 int v,nex; 20 }e[maxn*2]; 21 void add(int t1,int t2){ 22 e[++tot].v=t2;e[tot].nex=head[t1];head[t1]=tot; 23 } 24 void dfs(int k,int fa){ 25 sz[k]=1;h[k]=0; 26 for(int i=head[k];i;i=e[i].nex){ 27 if(e[i].v==fa)continue; 28 dfs(e[i].v,k); 29 sz[k]+=sz[e[i].v];h[k]^=(h[e[i].v]+o); 30 } 31 h[k]=h[k]+sz[k]*q+p; 32 } 33 void ch(int k,int fa){ 34 f[h[k]]=1; 35 for(int i=head[k];i;i=e[i].nex){ 36 int v=e[i].v; 37 if(v==fa)continue; 38 ll t1=h[k],t2=h[v]; 39 h[k]=h[k]-n*q-p;h[k]^=(h[v]+o); 40 h[v]=h[v]-sz[v]*q-p; 41 42 sz[k]=n-sz[v];sz[v]=n; 43 44 h[k]=h[k]+sz[k]*q+p; 45 h[v]^=(h[k]+o);h[v]=h[v]+n*q+p; 46 47 ch(v,k); 48 49 h[k]=t1,h[v]=t2;sz[v]=n-sz[k],sz[k]=n; 50 } 51 } 52 void work(int k,int fa){ 53 for(int i=head[k];i;i=e[i].nex){ 54 int v=e[i].v; 55 if(v==fa)continue; 56 if(in[v]>1){ 57 ll t1=h[k],t2=h[v]; 58 h[k]=h[k]-n*q-p;h[k]^=(h[v]+o); 59 h[v]=h[v]-sz[v]*q-p; 60 61 sz[k]=n-sz[v];sz[v]=n; 62 63 h[k]=h[k]+sz[k]*q+p; 64 h[v]^=(h[k]+o);h[v]=h[v]+n*q+p; 65 66 work(v,k); 67 68 h[k]=t1,h[v]=t2;sz[v]=n-sz[k],sz[k]=n; 69 70 } 71 else { 72 ll tmp=h[k]-n*q-p;tmp^=(h[v]+o); 73 tmp+=(ll)(n-1)*q+p; 74 if(f[tmp])ans=min(ans,v); 75 } 76 } 77 } 78 int main() 79 { 80 cin>>n; 81 for(int i=1,t1,t2;i<n;i++){ 82 scanf("%d%d",&t1,&t2); 83 add(t1,t2);add(t2,t1); 84 } 85 dfs(1,0);ch(1,0); 86 tot=0;memset(head,0,sizeof head); 87 n++;for(int i=1,t1,t2;i<n;i++){ 88 scanf("%d%d",&t1,&t2); 89 add(t1,t2);add(t2,t1); 90 in[t1]++;in[t2]++; 91 } 92 for(int i=1;i<=n;i++){ 93 if(in[i]>1){ 94 dfs(i,0);work(i,0); 95 break; 96 } 97 } 98 cout<<ans<<endl; 99 return 0; 100 }