题解:

树链剖分模板

注意:求最大值时,初始值要设为 -1e9(错了 n 次)

代码:

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
// Capacity of every fixed-size array below (adjacency list, per-node tables, segment tree).
const int N=400005;
// Buffer for the command word read by main() for each operation.
char s[100];
// tt[i]   : weight of node i as read from input
// fp[k]   : node stored at segment-tree position k (inverse of p[])
// e       : not referenced by the code in this file
// f[0][x] : parent of x (written in dfs1); the other rows are not used in this file
// jin/out : entry / exit timestamps assigned by dfs1 from the counter l
// sum[]   : segment-tree node sums
// val[k]  : weight of the node placed at position k (leaf values for build)
// n       : number of nodes; m: number of operations; x, y: scratch input variables
int tt[N],fp[N],e[N][3],f[25][N],jin[N],out[N],l,sum[N],val[N],n,x,y,m;
// data[]        : segment-tree node maxima
//                 NOTE(review): with `using namespace std;` this name may collide with
//                 std::data under C++17 on some toolchains - confirm on the target compiler.
// ne/fi/zz/tot  : adjacency list - fi[x] is the first edge of x, ne[i] the next edge,
//                 zz[i] the endpoint of edge i, tot the number of edges stored (0 ends a list)
// num[x]        : subtree size of x
// son[x]        : heavy child of x, or -1 if x has no child (main() fills with -1)
// top[x]        : head of the chain containing x
// pos           : next free segment-tree position handed out by dfs2
// deep[x]       : depth of x (root has depth 0)
// p[x]          : segment-tree position of node x
// fa[x]         : parent of x (0 for the root)
int data[N],ne[N],num[N],son[N],top[N],tot,fi[N],zz[N],pos,deep[N],p[N],fa[N];
// Insert the directed edge x -> y at the front of x's adjacency list.
void jb(int x,int y)
{
    const int id=++tot;   // edge slots are 1-based; index 0 terminates a list
    zz[id]=y;             // endpoint of the new edge
    ne[id]=fi[x];         // chain to the previous head
    fi[x]=id;             // new edge becomes the head
}
void dfs1(int x,int y,int z)
{
    jin[x]=++l;
    deep[x]=z;
    num[x]=1;
    fa[x]=f[0][x]=y;
    for (int i=fi[x];i;i=ne[i])
     {
         int k=zz[i];
         if (k!=y)
          {
              dfs1(k,x,z+1);
              num[x]+=num[k];
              if (son[x]==-1||(num[son[x]]<num[k]))son[x]=k;
          }
     }
    out[x]=++l; 
}
void dfs2(int x,int y)
{
    top[x]=y;
    if (son[x]!=-1)
     {
         p[x]=pos++;
         fp[p[x]]=x;
         dfs2(son[x],y);
     }
    else
     {
         p[x]=pos++;
         fp[p[x]]=x;
         return;
     } 
    for (int i=fi[x];i;i=ne[i])
     if (fa[x]!=zz[i]&&zz[i]!=son[x])
      dfs2(zz[i],zz[i]); 
}
// Recompute segment-tree node x (max and sum) from its two children.
void pushup(int x)
{
    const int lc=x*2;
    const int rc=lc+1;
    sum[x]=sum[lc]+sum[rc];
    data[x]=max(data[lc],data[rc]);
}
// Build the segment tree over positions [l, r] rooted at node x,
// taking leaf values from val[].
void build(int l,int r,int x)
{
    if (l!=r)
    {
        const int mid=(l+r)/2;
        build(l,mid,x*2);
        build(mid+1,r,x*2+1);
        pushup(x);
        return;
    }
    // Leaf: max and sum are both the single stored value.
    sum[x]=val[l];
    data[x]=val[l];
}
// Maximum over positions [x, y]; node s covers [l, r].
// A range disjoint from the query yields -1e9 so it never wins the max.
int query1(int x,int y,int l,int r,int s)
{
    if (r<x||y<l)return -1e9;
    if (x<=l&&r<=y)return data[s];
    const int mid=(l+r)/2;
    const int leftBest=query1(x,y,l,mid,s*2);
    const int rightBest=query1(x,y,mid+1,r,s*2+1);
    return max(leftBest,rightBest);
}
// Sum over positions [x, y]; node s covers [l, r].
// A range disjoint from the query contributes 0.
int query2(int x,int y,int l,int r,int s)
{
    if (r<x||y<l)return 0;
    if (x<=l&&r<=y)return sum[s];
    const int mid=(l+r)/2;
    const int leftSum=query2(x,y,l,mid,s*2);
    const int rightSum=query2(x,y,mid+1,r,s*2+1);
    return leftSum+rightSum;
}
// Point update: set position p to value q; node x covers [l, r].
// Ancestors are refreshed on the way back up.
void change(int p,int q,int l,int r,int x)
{
    if (l==r)
    {
        sum[x]=q;
        data[x]=q;
        return;
    }
    const int mid=(l+r)/2;
    if (mid<p)change(p,q,mid+1,r,x*2+1);
    else change(p,q,l,mid,x*2);
    pushup(x);
}
int find1(int x,int y)
{
    int temp=-1e9,f1=top[x],f2=top[y];
    while (f1!=f2)
     {
         if (deep[f1]<deep[f2])
          {
              swap(f1,f2);
              swap(x,y);
          }
         temp=max(temp,query1(p[f1],p[x],1,n,1));
        x=fa[f1],f1=top[x]; 
     }
    if (p[x]>p[y])swap(x,y);
    return max(temp,query1(p[x],p[y],1,n,1));
}
int find2(int x,int y)
{
    int temp=0,f1=top[x],f2=top[y];
    while (f1!=f2)
     {
         if (deep[f1]<deep[f2])
          {
              swap(f1,f2);
              swap(x,y);
          }
         temp+=query2(p[f1],p[x],1,n,1);
        x=fa[f1],f1=top[x]; 
     }
    if (p[x]>p[y])swap(x,y);
    return temp+query2(p[x],p[y],1,n,1);
}
int main()
{
    scanf("%d",&n);
    pos=1;tot=0;out[0]=2*n+1;
    memset(son,-1,sizeof son);
    for (int i=1;i<n;i++)
     {
         scanf("%d%d",&x,&y);
         jb(x,y);jb(y,x);
     }
    dfs1(1,0,0);
    dfs2(1,1);    
    for (int i=1;i<=n;i++)
     {
         scanf("%d",&tt[i]);
         val[p[i]]=tt[i];
     }     
    build(1,n,1);
    scanf("%d",&m);
    while (m--)
     {
         scanf("%s%d%d",&s,&x,&y);
         if (s[0]=='Q'&&s[1]=='M')printf("%d\n",find1(x,y));
         if (s[0]=='Q'&&s[1]=='S')printf("%d\n",find2(x,y));
         if (s[0]=='C')change(p[x],y,1,n,1);
     } 
    return 0; 
}

 

posted on 2017-12-04 18:50  宣毅鸣  阅读(139)  评论(0)  编辑  收藏  举报