通过这题我知道了一个鬼故事：trunc(ln(128)/ln(2)) 的结果竟然是 6（浮点误差使 ln(128)/ln(2) 算出来略小于 7，再向下取整就少了 1）……以后不敢轻易这么写了。
好了，言归正传。这题显然要构建虚树，但构建虚树之后该怎么做树形 DP 呢？
由于虚树上的点不仅有议事会，还可能是议事会之间的 LCA，所以我们要先求出虚树上每个点归哪个议事会管理。这可以通过两遍 DFS 求出（第一遍用儿子更新父亲，第二遍用父亲更新儿子）。
然后再考虑虚树上每条边所代表的原树结点分别归属哪个议事会就可以了。这个地方细节比较多，建议自己想，具体见代码注释。
{ Virtual-tree solution.
  Input: a tree with n nodes (rooted at 1), then m queries; each query lists
  q nodes.  Every node of the original tree is assigned to its nearest query
  node (ties go to the smaller node id), and for each query the program prints
  how many nodes each listed node owns, in the order they were given. }
type
  node = record
    po, next: longint;   { edge target; index of the next edge in this list }
  end;
  point = record
    fr, ds: longint;     { best owner found so far (fr) and distance to it (ds) }
  end;

var
  e: array[0..600010] of node;                { edge pool for adjacency lists }
  f, b, c, d, a, st, ans, p, s: array[0..300010] of longint;
  { a: DFS entry order;  s: subtree size;  d: depth (root = 0)
    p: head of adjacency list;  c: query nodes;  b: node -> query index
    st: stack for virtual-tree construction;  ans: per-query-index answers
    f: nodes counted directly at a virtual-tree node (see work1) }
  anc: array[0..300010, 0..20] of longint;    { anc[x, k] = 2^k-th ancestor }
  w: array[0..300010] of point;               { owner of each virtual-tree node }
  n, q, m, i, len, j, x, y, t, z: longint;
  lg: longint;                                { highest binary-lifting level }
  v: array[0..300010] of boolean;             { v[x]: x is in the current query }

procedure swap(var a, b: longint);
var
  c: longint;
begin
  c := a;
  a := b;
  b := c;
end;

{ Appends a directed edge x -> y to x's adjacency list. }
procedure add(x, y: longint);
begin
  inc(len);
  e[len].po := y;
  e[len].next := p[x];
  p[x] := len;
end;

{ DFS over the original tree from x: records entry order a[x], subtree size
  s[x], depth d[] and direct parent anc[y, 0] of each child y.
  s[y] = 0 doubles as the "not yet visited" test.
  NOTE(review): recursion depth equals tree depth; a path-shaped tree with
  large n may need a larger stack on the target judge - confirm. }
procedure dfs(x: longint);
var
  i, y: longint;
begin
  inc(t);
  a[x] := t;
  s[x] := 1;
  i := p[x];
  while i <> 0 do
  begin
    y := e[i].po;
    if s[y] = 0 then
    begin
      anc[y, 0] := x;
      d[y] := d[x] + 1;
      dfs(y);
      s[x] := s[x] + s[y];
    end;
    i := e[i].next;
  end;
end;

{ Quicksort of c[l..r] by DFS entry order a[]. }
procedure sort(l, r: longint);
var
  i, j, x: longint;
begin
  i := l;
  j := r;
  x := c[(l + r) shr 1];
  repeat
    while a[c[i]] < a[x] do
      inc(i);
    while a[x] < a[c[j]] do
      dec(j);
    if not (i > j) then
    begin
      swap(c[i], c[j]);
      inc(i);
      dec(j);
    end;
  until i > j;
  if l < j then
    sort(l, j);
  if i < r then
    sort(i, r);
end;

{ Lowest common ancestor of x and y via binary lifting. }
function lca(x, y: longint): longint;
var
  i: longint;
begin
  if x = y then
    exit(x);
  if d[x] < d[y] then
    swap(x, y);
  { Lift the deeper node up to the other node's depth. }
  if d[x] > d[y] then
    for i := lg downto 0 do
      if d[x] - (1 shl i) >= d[y] then
        x := anc[x, i];
  if x = y then
    exit(x);
  { Lift both while their ancestors differ; the common parent is the LCA. }
  for i := lg downto 0 do
    if anc[x, i] <> anc[y, i] then
    begin
      x := anc[x, i];
      y := anc[y, i];
    end;
  exit(anc[x, 0]);
end;

{ Relaxes a with the candidate (distance x, owner y): keeps the smaller
  distance, and on equal distance the smaller owner id. }
procedure get(var a: point; x, y: longint);
begin
  if (a.ds > x) or ((a.ds = x) and (a.fr > y)) then
  begin
    a.ds := x;
    a.fr := y;
  end;
end;

{ Returns the ancestor of x that lies h levels above it (h >= 0). }
function find(x, h: longint): longint;
var
  i: longint;
begin
  if h = 0 then
    exit(x);
  for i := lg downto 0 do
    if h >= (1 shl i) then
    begin
      x := anc[x, i];
      h := h - (1 shl i);
      if h = 0 then
        break;
    end;
  exit(x);
end;

{ Bottom-up pass over the virtual tree below x.
  w[x] becomes the nearest query node inside x's virtual subtree (x itself if
  x is a query node).  f[x] becomes the number of original-tree nodes in x's
  subtree that are not in any original branch leading to a virtual child. }
procedure work1(x: longint);
var
  i, y: longint;
begin
  if v[x] then
  begin
    w[x].fr := x;
    w[x].ds := 0;
  end
  else
  begin
    { Sentinel owner: n+1 loses every tie and the distance is huge. }
    w[x].fr := n + 1;
    w[x].ds := 10000010;
  end;
  f[x] := s[x];
  i := p[x];
  while i <> 0 do
  begin
    y := e[i].po;
    { Remove the whole branch under x containing y; its top is the ancestor
      of y exactly one level below x. }
    f[x] := f[x] - s[find(y, d[y] - d[x] - 1)];
    work1(y);
    get(w[x], w[y].ds + d[y] - d[x], w[y].fr);
    i := e[i].next;
  end;
end;

{ Top-down pass: lets each virtual child y also consider its parent's owner,
  so w[] ends up holding the nearest query node over the whole tree. }
procedure work2(x: longint);
var
  i, y: longint;
begin
  i := p[x];
  while i <> 0 do
  begin
    y := e[i].po;
    get(w[y], w[x].ds + d[y] - d[x], w[x].fr);
    work2(y);
    i := e[i].next;
  end;
end;

{ Distributes the original-tree nodes represented by x's virtual subtree among
  the query nodes, accumulating into ans[] (indexed through b[]).
  Also clears p[x], so the adjacency lists are empty for the next query. }
procedure calc(x: longint);
var
  i, y, l, h: longint;
begin
  { Nodes hanging directly off x (not on any virtual edge) go to x's owner. }
  inc(ans[b[w[x].fr]], f[x]);
  i := p[x];
  while i <> 0 do
  begin
    y := e[i].po;
    { Both endpoints share an owner: the whole strip between them goes to it. }
    if w[x].fr = w[y].fr then
      inc(ans[b[w[x].fr]], s[find(y, d[y] - d[x] - 1)] - s[y])
    else
    begin
      { l = distance between the two owners, measured through this edge. }
      l := w[x].ds + w[y].ds + d[y] - d[x];
      { h = number of levels above y that still belong to y's owner. }
      h := l div 2 - w[y].ds;
      { Exact tie at the midpoint: the smaller owner id wins. }
      if (l mod 2 = 0) and (w[x].fr < w[y].fr) then
        dec(h);
      { Top node of the part of the strip owned by y's owner. }
      h := find(y, h);
      inc(ans[b[w[x].fr]], s[find(y, d[y] - d[x] - 1)] - s[h]);
      inc(ans[b[w[y].fr]], s[h] - s[y]);
    end;
    calc(y);
    i := e[i].next;
  end;
  p[x] := 0;
end;

begin
  readln(n);
  for i := 1 to n - 1 do
  begin
    readln(x, y);
    add(x, y);
    add(y, x);
  end;
  dfs(1);

  { Number of lifting levels, computed with exact integer arithmetic.  The
    floating-point form trunc(ln(n)/ln(2)) can round down at exact powers of
    two (e.g. trunc(ln(128)/ln(2)) may give 6).  Since every depth is <= n-1
    < 2^(lg+1), levels 0..lg are enough for lca and find. }
  lg := 0;
  while (1 shl (lg + 1)) <= n do
    inc(lg);
  for j := 1 to lg do
    for i := 1 to n do
      anc[i, j] := anc[anc[i, j - 1], j - 1];

  { From here on e[] / p[] hold directed virtual-tree edges only. }
  len := 0;
  fillchar(p, sizeof(p), 0);
  readln(m);
  while m > 0 do
  begin
    dec(m);
    len := 0;
    readln(q);
    for i := 1 to q do
    begin
      read(c[i]);
      b[c[i]] := i;
      v[c[i]] := true;
    end;
    sort(1, q);

    { Build the virtual tree with a stack holding the current root path;
      node 1 is always its root. }
    st[1] := 1;
    t := 1;
    for i := 1 to q do
    begin
      x := c[i];
      z := lca(x, st[t]);
      while d[z] < d[st[t]] do
      begin
        if d[z] >= d[st[t - 1]] then
        begin
          add(z, st[t]);
          dec(t);
          if st[t] <> z then
          begin
            inc(t);
            st[t] := z;
          end;
          break;
        end;
        add(st[t - 1], st[t]);
        dec(t);
      end;
      if st[t] <> x then
      begin
        inc(t);
        st[t] := x;
      end;
    end;
    while t > 1 do
    begin
      add(st[t - 1], st[t]);
      dec(t);
    end;

    work1(1);
    work2(1);
    calc(1);

    { Print answers in input order and reset per-query state. }
    for i := 1 to q do
    begin
      write(ans[i], ' ');
      ans[i] := 0;
      b[c[i]] := 0;
      v[c[i]] := false;
    end;
    writeln;
  end;
end.