bzoj 1036 树链剖分+线段树 裸题
题意:中文题
思路:树链剖分裸题,线段树写得比较搓,(在线段树上修改节点u的时候应该修改u映射到线段树后的节点序号,这里wa了半年,真的是半年)
AC代码:
#include "iostream" #include "string.h" #include "stack" #include "queue" #include "string" #include "vector" #include "set" #include "map" #include "algorithm" #include "stdio.h" #include "math.h" #define ll long long #define bug cout<<"UUUUUUUU"<<endl; #define mem(a) memset(a,0,sizeof(a)) using namespace std; const int MAX=1e5+100; int son[MAX],siz[MAX],fa[MAX],de[MAX],top[MAX],tip[MAX]; struct Edge{ int to; int next; }; Edge e[MAX<<1]; int tot=1,cnt=1,head[MAX]; void add(int u, int v){ e[tot].to=v; e[tot].next=head[u]; head[u]=tot++; } void Dfs1(int u, int f){ siz[u]=1; fa[u]=f; de[u]=de[f]+1; for(int i=head[u]; i!=-1; i=e[i].next){ int v=e[i].to; if(v==fa[u]) continue; Dfs1(v,u); siz[u]+=siz[v]; if(siz[v]>siz[son[u]]) son[u]=v; } } void Dfs2(int u, int tp){ tip[u]=cnt++; top[u]=tp; if(son[u]) Dfs2(son[u],tp); for(int i=head[u]; i!=-1; i=e[i].next){ int v=e[i].to; if(v!=fa[u]&&son[u]!=v) Dfs2(v,v); } } int w[MAX<<2],sum[MAX<<2],m[MAX]; void push_up(int rt){ m[rt]=max(m[rt<<1],m[rt<<1|1]); sum[rt]=(sum[rt<<1]+sum[rt<<1|1]); } void Build(int rt, int l, int r){ if(l==r){ sum[rt]=m[rt]=w[l]; return; } int mid=l+r>>1; Build(rt<<1,l,mid); Build(rt<<1|1,mid+1,r); push_up(rt); } void update(int rt, int L, int R, int p,int w){ if(L==R){ sum[rt]=m[rt]=w; return; } int mid=L+R>>1; if(p<=mid) update(rt<<1,L,mid,p,w); else update(rt<<1|1,mid+1,R,p,w); push_up(rt); } int query_max(int rt,int l, int r, int L, int R){ if(l==L&&r==R) return m[rt]; int mid=L+R>>1; if(r<=mid) return query_max(rt<<1,l,r,L,mid); else if(l>mid) return query_max(rt<<1|1,l,r,mid+1,R); else return max(query_max(rt<<1,l,mid,L,mid),query_max(rt<<1|1,mid+1,r,mid+1,R)); } int query_sum(int rt, int l, int r,int L, int R){ if(l==L&&r==R) return sum[rt]; int mid=L+R>>1; if(r<=mid) return query_sum(rt<<1,l,r,L,mid); else if(l>mid) return query_sum(rt<<1|1,l,r,mid+1,R); else return query_sum(rt<<1,l,mid,L,mid)+query_sum(rt<<1|1,mid+1,r,mid+1,R); } void get_sum(int u,int v,int n){ int ans=0; while(top[u]!=top[v]){ if(de[top[v]]<de[top[u]]) swap(u,v); ans+=query_sum(1,tip[top[v]],tip[v],1,n); v=fa[top[v]]; } if(de[u]<de[v]) swap(u,v); ans+=query_sum(1,tip[v],tip[u],1,n); printf("%d\n",ans); } void get_max(int u, int v,int n){ int ans=-1<<30; while(top[u]!=top[v]){ if(de[top[v]]<de[top[u]]) swap(u,v); ans=max(ans,query_max(1,tip[top[v]],tip[v],1,n)); v=fa[top[v]]; } if(de[u]<de[v]) swap(u,v); ans=max(ans,query_max(1,tip[v],tip[u],1,n)); printf("%d\n",ans); } int main(){ int n,q,a,b; char s[15]; scanf("%d",&n); memset(head,-1,sizeof(head)); for(int i=1; i<n; ++i){ scanf("%d%d",&a,&b); add(a,b); add(b,a); } Dfs1(1,1); Dfs2(1,1); for(int i=1; i<=n; ++i) scanf("%d",&w[tip[i]]); Build(1,1,n); scanf("%d",&q); while(q--){ getchar(); scanf("%s%d%d",s,&a,&b); if(s[1]=='M') get_max(a,b,n); else if(s[1]=='S') get_sum(a,b,n); else update(1,1,n,tip[a],b); } return 0; }