2020 hdu多校赛 第三场 1003 Tokitsukaze and Colorful Tree
题意:
给你一棵树(n<=1e5),每个节点有颜色col[i]<=n,和权值val[i]<2^20,
每次修改一个节点的权值或颜色,求每次修改之后每个节点与不是他的祖先或在他子树内的且颜色相同的点的权值异或之和。
首先,我们考虑简化问题,如果没有颜色、祖先和子树限制,单纯求各个点对的异或值之和应该怎么求?
很简单,我们设sum[i]为二进制第 i 位为1的数有多少个,然后枚举每个点的权值在着一位为1还是0即可。
那么加上颜色限制呢?
设sum[x][i]为颜色x的点里面二进制第 i 位为 1 的数有多少个,其他不变。
让我们再加上祖先限制。
此时,一个点会对他子树内所有点产生影响,我们就可以考虑在 dfs 序上用线段树差分维护sum[x][i],对 x 点查询的时候直接看1~x的区间和就可以知道x的祖先的影响了。
最后,让我们加上子树限制。
与上文同理,我们用dfs序维护,直接区间求和即可。
(出题人竟然卡我内存,QAQ,赌了一把改成short才过,不然不是MLE就是RE
1 #include<iostream> 2 #include<cstdlib> 3 #include<cstring> 4 #include<algorithm> 5 #include<cmath> 6 #include<cstdio> 7 #define N 100005 8 using namespace std; 9 int T,n,zz,a[N]; 10 struct ro{ 11 int to,next; 12 }road[N*2]; 13 void build(int x,int y) 14 { 15 zz++; 16 road[zz].to=y; 17 road[zz].next=a[x]; 18 a[x]=zz; 19 } 20 int fa[N],l[N],r[N],dfn[N],zz1,dl[N]; 21 void dfs(int x) 22 { 23 zz1++; 24 dfn[x]=zz1; 25 dl[zz1]=x; 26 l[x]=zz1; 27 for(int i=a[x];i;i=road[i].next) 28 { 29 int y=road[i].to; 30 if(y==fa[x])continue; 31 fa[y]=x; 32 dfs(y); 33 } 34 r[x]=zz1; 35 } 36 int val[N],col[N],zz2; 37 struct no{ 38 int l,r,size[2]; 39 short da[2][20]; 40 }node[int(N*30)]; 41 int A[N],sm[2]; 42 void add(int l,int r,int x,int to,int op,int dat) 43 { 44 if(l==r) 45 { 46 node[x].size[op]+=dat; 47 for(int i=0;i<20;i++) 48 { 49 node[x].da[op][i]+=A[i]; 50 } 51 return; 52 } 53 int mid=(l+r)>>1; 54 if(to>mid) 55 { 56 if(!node[x].r) 57 { 58 zz2++; 59 node[zz2].l=node[zz2].r=0; 60 node[zz2].size[0]=node[zz2].size[1]=0; 61 memset(node[zz2].da,0,sizeof(node[zz2].da)); 62 node[x].r=zz2; 63 } 64 add(mid+1,r,node[x].r,to,op,dat); 65 } 66 else 67 { 68 if(!node[x].l) 69 { 70 zz2++; 71 node[zz2].l=node[zz2].r=0; 72 node[zz2].size[0]=node[zz2].size[1]=0; 73 memset(node[zz2].da,0,sizeof(node[zz2].da)); 74 node[x].l=zz2; 75 } 76 add(l,mid,node[x].l,to,op,dat); 77 } 78 node[x].size[op]=node[node[x].l].size[op]+node[node[x].r].size[op]; 79 for(int i=0;i<20;i++) 80 { 81 node[x].da[op][i]=node[node[x].l].da[op][i]+node[node[x].r].da[op][i]; 82 } 83 } 84 void get(int l,int r,int left,int right,int x,int op,int op2) 85 { 86 if(!x)return ; 87 if(left>right)return; 88 if(l==left&&r==right) 89 { 90 sm[op]+=node[x].size[op]*op2; 91 for(int i=0;i<20;i++) 92 { 93 A[i]+=node[x].da[op][i]*op2; 94 } 95 return; 96 } 97 int mid=(l+r)>>1; 98 if(left>mid) 99 { 100 get(mid+1,r,left,right,node[x].r,op,op2); 101 } 102 else if(right<=mid) 103 { 104 get(l,mid,left,right,node[x].l,op,op2); 105 } 106 else 107 { 108 get(l,mid,left,mid,node[x].l,op,op2); 109 get(mid+1,r,mid+1,right,node[x].r,op,op2); 110 } 111 } 112 int cnt[N],root[N],sum[N][20]; 113 void del(int x) 114 { 115 for(int i=0;i<20;i++) 116 { 117 if(val[x]&(1<<i)) A[i]=-1,sum[col[x]][i]--; 118 else A[i]=0; 119 } 120 add(1,n+1,root[col[x]],dfn[x],0,-1); 121 add(1,n+1,root[col[x]],l[x],1,-1); 122 for(int i=0;i<20;i++) 123 { 124 if(val[x]&(1<<i)) A[i]=1; 125 else A[i]=0; 126 } 127 add(1,n+1,root[col[x]],r[x]+1,1,1); 128 cnt[col[x]]--; 129 } 130 long long work(int x) 131 { 132 for(int i=0;i<20;i++) A[i]=0; 133 sm[0]=sm[1]=0; 134 get(1,n+1,l[x],r[x],root[col[x]],0,1); 135 get(1,n+1,1,dfn[x],root[col[x]],1,1); 136 long long ans=0; 137 for(int i=0;i<20;i++) 138 { 139 if(val[x]&(1<<i)) 140 { 141 ans+=1ll*((cnt[col[x]]-sum[col[x]][i])-(sm[0]+sm[1]-A[i]))*(1<<i); 142 } 143 else 144 { 145 ans+=1ll*(sum[col[x]][i]-A[i])*(1<<i); 146 } 147 148 } 149 return ans; 150 } 151 void ins(int x) 152 { 153 if(!root[col[x]]) 154 { 155 zz2++; 156 root[col[x]]=zz2; 157 node[zz2].l=node[zz2].r=0; 158 node[zz2].size[0]=node[zz2].size[1]=0; 159 memset(node[zz2].da,0,sizeof(node[zz2].da)); 160 } 161 for(int i=0;i<20;i++) 162 { 163 if(val[x]&(1<<i)) A[i]=1,sum[col[x]][i]++; 164 else A[i]=0; 165 } 166 add(1,n+1,root[col[x]],dfn[x],0,1); 167 add(1,n+1,root[col[x]],l[x],1,1); 168 for(int i=0;i<20;i++) 169 { 170 if(val[x]&(1<<i)) A[i]=-1; 171 else A[i]=0; 172 } 173 add(1,n+1,root[col[x]],r[x]+1,1,-1); 174 cnt[col[x]]++; 175 } 176 int main() 177 { 178 // freopen("1003.in","r",stdin); 179 // freopen("1.out","w",stdout); 180 scanf("%d",&T); 181 while(T--) 182 { 183 scanf("%d",&n); 184 for(int i=1;i<=n;i++) scanf("%d",&col[i]); 185 for(int i=1;i<=n;i++) scanf("%d",&val[i]); 186 zz=0; 187 memset(a,0,sizeof(a)); 188 for(int i=1;i<n;i++) 189 { 190 int x,y; 191 scanf("%d%d",&x,&y); 192 build(x,y); 193 build(y,x); 194 } 195 memset(fa,0,sizeof(fa)); 196 memset(dl,0,sizeof(dl)); 197 memset(l,0,sizeof(l)); 198 memset(r,0,sizeof(r)); 199 memset(cnt,0,sizeof(cnt)); 200 memset(sum,0,sizeof(sum)); 201 memset(root,0,sizeof(root)); 202 zz1=0; 203 dfs(1); 204 205 zz2=0; 206 long long ans=0; 207 for(int i=1;i<=n;i++) 208 { 209 if(!root[col[i]]) 210 { 211 zz2++; 212 root[col[i]]=zz2; 213 node[zz2].l=node[zz2].r=0; 214 node[zz2].size[0]=node[zz2].size[1]=0; 215 memset(node[zz2].da,0,sizeof(node[zz2].da)); 216 } 217 ans+=work(i); 218 ins(i); 219 } 220 printf("%lld\n",ans); 221 int q; 222 scanf("%d",&q); 223 for(int i=1;i<=q;i++) 224 { 225 int op,x,y; 226 scanf("%d%d%d",&op,&x,&y); 227 del(x); 228 ans-=work(x); 229 // cout<<i<<' '<<ans<<endl; 230 if(op==1) val[x]=y; 231 else col[x]=y; 232 ans+=work(x); 233 ins(x); 234 printf("%lld\n",ans); 235 } 236 } 237 return 0; 238 }