题解:
树链剖分模板
注意：求路径最大值时，初始值（以及查询区间不相交时的返回值）要取-1e9，不能取0，因为点权可能为负数（错了n次）
代码:
// Heavy-light decomposition (树链剖分) on a vertex-weighted tree.
//
// Input format (as read by main below):
//   n, then n-1 undirected edges "x y", then n vertex weights,
//   then m operations, each "OP x y" where OP is one of
//     QMAX x y   -> print the maximum vertex weight on the path x..y
//     QSUM x y   -> print the sum of vertex weights on the path x..y
//     CHANGE x y -> set the weight of vertex x to y
// Vertices are numbered 1..n; 0 is used as the "no parent" / "no heavy child"
// sentinel.
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
using namespace std;

// Array bound: must be >= 4*n (segment tree) and >= 2*(n-1) (edge list).
// NOTE(review): this assumes n <= 100000 -- confirm against the problem limits.
const int N = 400005;

// Identity for max queries. Weights may be negative, so 0 is NOT a valid
// starting value for a maximum.
const int NEG_INF = -1000000000;

int n;

// Adjacency list: fi[x] = first edge out of x, ne[e] = next edge, zz[e] = target.
int fi[N], ne[N], zz[N], tot;

// Heavy-light decomposition data.
int fa[N];    // parent (0 for the root)
int deep[N];  // depth (root = 0)
int num[N];   // subtree size; num[0] stays 0 and acts as a sentinel
int son[N];   // heavy child (0 if the vertex is a leaf)
int top[N];   // head of the heavy chain containing the vertex
int p[N];     // position of the vertex in segment-tree order (1..n)
int pos;      // next free position while dfs2 assigns p[]

// Segment tree over positions 1..n.
int val[N];        // val[p[v]] = initial weight of vertex v
int mx[N];         // range maximum (named `mx`, not `data`: `data` is ambiguous
                   // with std::data under `using namespace std` in C++17)
long long sum[N];  // range sum; long long so n * |weight| cannot overflow

// Adds the directed edge x -> y to the adjacency list.
void jb(int x, int y) {
    ne[++tot] = fi[x];
    fi[x] = tot;
    zz[tot] = y;
}

// First DFS: fills parent, depth, subtree size and heavy child.
void dfs1(int x, int parent, int d) {
    fa[x] = parent;
    deep[x] = d;
    num[x] = 1;
    for (int i = fi[x]; i; i = ne[i]) {
        int k = zz[i];
        if (k == parent) continue;
        dfs1(k, x, d + 1);
        num[x] += num[k];
        // son[x] starts at 0 and num[0] == 0, so the first child always wins;
        // later children replace it only if strictly larger.
        if (num[son[x]] < num[k]) son[x] = k;
    }
}

// Second DFS: assigns positions so that every heavy chain is contiguous,
// visiting the heavy child first.
void dfs2(int x, int chainTop) {
    top[x] = chainTop;
    p[x] = pos++;
    if (!son[x]) return;  // leaf: no children at all
    dfs2(son[x], chainTop);
    for (int i = fi[x]; i; i = ne[i]) {
        int k = zz[i];
        if (k != fa[x] && k != son[x]) dfs2(k, k);  // a light child starts a new chain
    }
}

// Recomputes node x from its two children.
void pushup(int x) {
    mx[x] = max(mx[x * 2], mx[x * 2 + 1]);
    sum[x] = sum[x * 2] + sum[x * 2 + 1];
}

// Builds node x covering positions [l, r] from val[].
void build(int l, int r, int x) {
    if (l == r) {
        mx[x] = val[l];
        sum[x] = val[l];
        return;
    }
    int mid = (l + r) / 2;
    build(l, mid, x * 2);
    build(mid + 1, r, x * 2 + 1);
    pushup(x);
}

// Maximum over positions [x, y]; node s covers [l, r].
// Returns NEG_INF when the ranges are disjoint.
int query1(int x, int y, int l, int r, int s) {
    if (x > r || y < l) return NEG_INF;
    if (x <= l && r <= y) return mx[s];
    int mid = (l + r) / 2;
    return max(query1(x, y, l, mid, s * 2), query1(x, y, mid + 1, r, s * 2 + 1));
}

// Sum over positions [x, y]; node s covers [l, r].
long long query2(int x, int y, int l, int r, int s) {
    if (x > r || y < l) return 0;
    if (x <= l && r <= y) return sum[s];
    int mid = (l + r) / 2;
    return query2(x, y, l, mid, s * 2) + query2(x, y, mid + 1, r, s * 2 + 1);
}

// Sets the value at position `at` to `v`; node x covers [l, r].
void change(int at, int v, int l, int r, int x) {
    if (l == r) {
        mx[x] = v;
        sum[x] = v;
        return;
    }
    int mid = (l + r) / 2;
    if (at <= mid) change(at, v, l, mid, x * 2);
    else change(at, v, mid + 1, r, x * 2 + 1);
    pushup(x);
}

// Maximum vertex weight on the tree path x..y.
int find1(int x, int y) {
    int res = NEG_INF;
    while (top[x] != top[y]) {
        // Always lift the endpoint whose chain head is deeper.
        if (deep[top[x]] < deep[top[y]]) swap(x, y);
        res = max(res, query1(p[top[x]], p[x], 1, n, 1));
        x = fa[top[x]];
    }
    if (p[x] > p[y]) swap(x, y);
    return max(res, query1(p[x], p[y], 1, n, 1));
}

// Sum of vertex weights on the tree path x..y.
long long find2(int x, int y) {
    long long res = 0;
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) swap(x, y);
        res += query2(p[top[x]], p[x], 1, n, 1);
        x = fa[top[x]];
    }
    if (p[x] > p[y]) swap(x, y);
    return res + query2(p[x], p[y], 1, n, 1);
}

int main() {
    // n <= 0 would make build(1, n, 1) recurse forever.
    if (scanf("%d", &n) != 1 || n <= 0) return 0;
    for (int i = 1; i < n; i++) {
        int x = 0, y = 0;
        scanf("%d%d", &x, &y);
        jb(x, y);
        jb(y, x);
    }
    dfs1(1, 0, 0);
    pos = 1;  // positions are 1..n
    dfs2(1, 1);
    for (int i = 1; i <= n; i++) {
        int w = 0;
        scanf("%d", &w);
        val[p[i]] = w;
    }
    build(1, n, 1);

    int m = 0;  // stays 0 (no operations) if the read fails
    scanf("%d", &m);
    char s[100];
    while (m--) {
        int x = 0, y = 0;
        // Pass `s` (char*), not `&s` (char(*)[100]); bound the width to the buffer.
        if (scanf("%99s%d%d", s, &x, &y) != 3) break;
        if (s[0] == 'Q' && s[1] == 'M') printf("%d\n", find1(x, y));
        else if (s[0] == 'Q' && s[1] == 'S') printf("%lld\n", find2(x, y));
        else if (s[0] == 'C') change(p[x], y, 1, n, 1);
    }
    return 0;
}