由于查询的是树链的并的信息,同时信息不能高效合并,只能考虑用bitset维护,小范围暴力预处理以便从bitset算出答案
对树分块,保证每块是连通的且直径较小,对分出的块缩点建新树,在新树上建树上ST表,用bitset保存信息,于是每条链只需用4个bitset合并,再暴力加上零散部分
考虑到bitset的复杂度非常高,分块部分有很大的调整空间,不会成为瓶颈
#include<cstdio> #include<cstring> const int M=110007,N=946,N4=N*4,B=50; char buf[M*100],*ptr=buf-1; int _(){ int x=0,c=*++ptr; while(c<48)c=*++ptr; while(c>47)x=x*10+c-48,c=*++ptr; return x; } typedef unsigned int bits[N]; void _or(bits a,bits b){for(int i=0;i<N;i+=2)b[i]|=a[i],b[i+1]|=a[i+1];} void _set(bits a,int x){a[x>>5]|=1<<x;} int n,m,la=0,es[2*M],enx[2*M],e0[M],ep=2,v[M]; int fa[M],sz[M],top[M],dep[M],son[M]; bits ans; void f1(int w,int pa){ dep[w]=dep[fa[w]=pa]+1; sz[w]=1; for(int i=e0[w];i;i=enx[i]){ int u=es[i]; if(u!=pa){ f1(u,w); sz[w]+=sz[u]; if(sz[u]>sz[son[w]])son[w]=u; } } } void f2(int w,int tp){ top[w]=tp; if(son[w])f2(son[w],tp); for(int i=e0[w];i;i=enx[i]){ int u=es[i]; if(u!=fa[w]&&u!=son[w])f2(u,u); } } int lca(int x,int y){ int a=top[x],b=top[y]; while(a!=b){ if(dep[a]>dep[b])x=fa[a],a=top[x]; else y=fa[b],b=top[y]; } return dep[x]<dep[y]?x:y; } void maxs(int&a,int b){if(a<b)a=b;} int id[M],idp=0,md[M],rt[M],fas[M/B][19],deps[M],e1[M]; bits st[M/B][19]; void f5(int w){ for(int i=e1[w];i;i=enx[i]){ int u=es[i]; deps[u]=deps[w]+1; f5(u); } } void f4(int w){ id[w]=idp; for(int i=e0[w];i;i=enx[i]){ int u=es[i]; if(u!=fa[w]&&!id[u])f4(u); } } void f3(int w){ for(int i=e0[w];i;i=enx[i]){ int u=es[i]; if(u!=fa[w]){ f3(u); if(!id[u])maxs(md[w],md[u]+1); } } if(w==1||md[w]==B){ rt[++idp]=w; f4(w); } } int log_2[M]; void cal(int x,int y){ int D=deps[id[x]]-deps[id[y]]-1; if(D<=0){ _set(ans,v[y]); while(x!=y)_set(ans,v[x]),x=fa[x]; return; } int a=rt[id[x]]; while(x!=a)_set(ans,v[x]),x=fa[x]; x=id[x]; int d=log_2[D]; _or(st[x][d],ans); D-=1<<d; if(D){ for(int i=0;i<19;++i)if(D>>i&1)x=fas[x][i]; _or(st[x][d],ans); } x=rt[fas[x][d]]; while(x!=y)x=fa[x],_set(ans,v[x]); } int ls[277],rs[277],ms[33][277]; int pw[33][30007]; int xs[N4+5]; int _cal(bits x,int k){ int s=0,xp=0; for(int i=0;i<N;++i){ xs[xp]=x[i]&255; xs[xp+1]=x[i]>>8&255; xs[xp+2]=x[i]>>16&255; xs[xp+3]=x[i]>>24&255; xp+=4; } for(int i=0,d=0;i<N4;++i){ while(xs[i]==255)++i,d+=8; s+=pw[k][ls[xs[i]]+d]; s+=ms[k][xs[i]]; d=rs[xs[i]]; } return s; } int main(){ fread(buf,1,sizeof(buf),stdin)[buf]=0; for(int i=1;i<=30000;++i)pw[0][i]=1; for(int t=1;t<=30;++t) for(int i=1;i<=30000;++i)pw[t][i]=pw[t-1][i]*i; for(int i=1;i<255;++i){ int vs[20],vp=0; for(int a=0,b=0;a<8;a=b){ b=a; if(~i>>a&1){ ++b; continue; } while(i>>b&1)++b; vs[vp++]=b-a; } if(i&1)ls[i]=vs[0]; if(i>>7&1)rs[i]=vs[vp-1]; for(int j=(i&1);j<vp-(i>>7&1);++j){ for(int k=0;k<=30;++k)ms[k][i]+=pw[k][vs[j]]; } } n=_();m=_(); for(int i=1;i<=n;++i)v[i]=_(); for(int i=1,a,b;i<n;++i){ a=_();b=_(); es[ep]=b;enx[ep]=e0[a];e0[a]=ep++; es[ep]=a;enx[ep]=e0[b];e0[b]=ep++; } f1(1,0);f2(1,1); f3(1); for(int i=1;i<=idp;++i){ int w=rt[i],u=fa[rt[fas[i][0]=id[fa[rt[i]]]]]; while(w!=u)_set(st[i][0],v[w]),w=fa[w]; } for(int i=1;i<19;++i){ for(int j=1;j<=idp;++j){ int k=fas[j][i-1]; if(fas[j][i]=fas[k][i-1]){ memcpy(st[j][i],st[j][i-1],sizeof(bits)); _or(st[k][i-1],st[j][i]); } } } log_2[0]=-1; for(int i=1;i<=idp;++i){ log_2[i]=log_2[i>>1]+1; if(fas[i][0]){ int a=fas[i][0]; es[ep]=i;enx[ep]=e1[a];e1[a]=ep++; } } f5(idp); while(m--){ memset(ans,0,sizeof(ans)); for(int c=_();c;--c){ int x=_()^la,y=_()^la,z=lca(x,y); cal(x,z); cal(y,z); } printf("%u\n",la=_cal(ans,_())); } return 0; }