hdu4918 Query on the subtree
树分治,设当前树的分治中心为x,其子树分治中心为y,则设father[y]=x,分治下去则可以得到一颗重心树,而且树的深度是logn。
询问操作(x,d),只需要查询重心树上x到重心树根节点上的节点的累加和。假设当前节点是y,那么节点y可以贡献的答案是那些以y为分治中心且到y距离为d-dis(x,y)的节点的总和。当然这样可能会出现重复的情况,重复情况只会出现在包含x的那颗子树上,因此减掉即可。修改操作类似。复杂度O(nlognlogn)
代码
#include<cstdio> #include<cstring> #define N 200010 #define LL long long using namespace std; int dp,pre[N],p[N],tt[N],vis[N],father[N],s[N],tmp,m; int n,a,b,i,w[N],L,cnt,tot,len[N],Len[N],start[N],Start[N],v[N]; int deep[N],ss[N][21],fa[N]; int c[N*50]; int min(int a,int b) { if (a<b) return a;return b; } int lowbit(int x) { return x&(-x); } void cc(int x,int w,int y) { while (x<=L) { c[y+x]+=w; x+=lowbit(x); } } LL sum(int x,int y) { LL ans=0; while (x>0) { ans+=c[y+x]; x-=lowbit(x); } return ans; } void link(int x,int y) { dp++;pre[dp]=p[x];p[x]=dp;tt[dp]=y; } void gao(int x) { int i; i=p[x]; while (i) { if (tt[i]!=fa[x]) { fa[tt[i]]=x; deep[tt[i]]=deep[x]+1; gao(tt[i]); } i=pre[i]; } } int lca(int x,int y) { if(deep[x]>deep[y])x^=y^=x^=y; int i; for(i=19;i>=0;i--) { if(deep[y]-deep[x]>=(1<<i)) { y=ss[y][i]; } } if(x==y)return x; for(i=19;i>=0;i--) { if(ss[x][i]!=ss[y][i]) { x=ss[x][i]; y=ss[y][i]; } } return fa[x]; } void getroot(int x,int fa,int sum) { int i,flag=0; i=p[x];s[x]=1; while (i) { if ((!vis[tt[i]])&&(tt[i]!=fa)) { getroot(tt[i],x,sum); s[x]+=s[tt[i]]; if (s[tt[i]]>sum/2) flag=1; } i=pre[i]; } if (sum-s[x]>sum/2) flag=1; if (!flag) tmp=x; } void dfs(int x,int fa,int dis) { int i; i=p[x]; if (dis>cnt) cnt=dis; v[dis]+=w[x]; while (i) { if ((!vis[tt[i]])&&(tt[i]!=fa)) dfs(tt[i],x,dis+1); i=pre[i]; } } void clear() { int i; for (i=1;i<=cnt;i++) v[i]=0;cnt=0; } int work(int x,int fa,int sum) { int i,root,t; getroot(x,0,sum); root=tmp; father[root]=fa; i=p[root]; vis[root]=1; while (i) { if (!vis[tt[i]]) { if (s[root]>s[tt[i]]) t=work(tt[i],root,s[tt[i]]); else t=work(tt[i],root,sum-s[root]); //------dist(root,point in subtree t)-------- dfs(tt[i],0,2); Len[t]=cnt; Start[t]=tot; for (int j=1;j<=cnt;j++) { L=cnt; cc(j,v[j],Start[t]); } tot+=cnt; clear(); } i=pre[i]; } vis[root]=0; //--------dist(root,all point)---------- dfs(root,0,1); len[root]=cnt; start[root]=tot; for (i=1;i<=cnt;i++) { L=cnt; cc(i,v[i],start[root]); } tot+=cnt; clear(); return root; } LL query(int x,int d) { int y=0,z=x,t; LL ans=0; while (x) { t=lca(x,z); t=deep[x]+deep[z]-2*deep[t]; L=len[x]; ans+=sum(min(L,d-t+1),start[x]); if (y) { L=Len[y]; ans-=sum(min(L,d-t+1),Start[y]); } y=x; x=father[x]; } return ans; } void change(int x,int w) { int y=0,z=x,t; while (x) { t=lca(x,z); t=deep[x]+deep[z]-2*deep[t]; L=len[x]; cc(t+1,w,start[x]); if (y) { L=Len[y]; cc(t+1,w,Start[y]); } y=x; x=father[x]; } } int main() { while (scanf("%d%d",&n,&m)!=EOF) { dp=0;memset(p,0,sizeof(p)); for (i=1;i<=tot;i++) c[i]=0;tot=0; for (i=1;i<=n;i++) scanf("%d",&w[i]); for (i=1;i<n;i++) { scanf("%d%d",&a,&b); link(a,b); link(b,a); } gao(1); for(i=1;i<=n;i++) ss[i][0]=fa[i]; for(int h=1;h<20;h++) { for(i=1;i<=n;i++) { ss[i][h]=ss[ss[i][h-1]][h-1]; } } work(1,0,n); for (i=1;i<=m;i++) { getchar(); char ch; scanf("%c%d%d",&ch,&a,&b); if (ch=='?') printf("%I64d\n",query(a,b)); else { change(a,b-w[a]); w[a]=b; } } } }