luogu P3313 [SDOI2014]旅行
动态开点线段树。
对每个宗教记一个root,然后动态开点,配合树剖即可切掉本题。
线段树大小开NlogN即可。
Code
#include<iostream> #include<cstdio> #define N 100010 using namespace std; int n,q,w[N],c[N],Head[N],nex[N*2],ver[N*2],root[N],d[N],son[N],top[N],size[N],id[N],f[N],tot,cnt,tmp,p; char op[3]; struct Segmenttree{ int lc,rc,dat,sum; }t[N*30]; void ADD(int x,int y){ver[++tot]=y;nex[tot]=Head[x];Head[x]=tot;} inline int read(){ char c=getchar();int x=0,flag=1; while(c<'0' || c>'9'){if(c=='-') flag=-1;c=getchar();} while(c>='0' && c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();} return x*flag; } int build(){cnt++;t[cnt].lc=t[cnt].rc=t[cnt].dat=t[cnt].sum=0;return cnt;} void dfs(int x,int fa,int depth){ size[x]=1;d[x]=depth;f[x]=fa;int maxson=-1; for(int i=Head[x];i;i=nex[i]){ int y=ver[i];if(y==fa) continue; dfs(y,x,depth+1);size[x]+=size[y];if(size[y]>maxson) maxson=size[y],son[x]=y; } } void dfs2(int x,int topf){ id[x]=++tmp;top[x]=topf;if(!son[x]) return;dfs2(son[x],topf); for(int i=Head[x];i;i=nex[i]){ int y=ver[i];if(y==f[x] || y==son[x]) continue;dfs2(y,y); } } void add(int p,int l,int r,int x,int k){ if(l==r){t[p].dat=t[p].sum=k;return;} int mid=(l+r)>>1; if(x<=mid){if(!t[p].lc) t[p].lc=build();add(t[p].lc,l,mid,x,k);} else{if(!t[p].rc) t[p].rc=build();add(t[p].rc,mid+1,r,x,k);} t[p].dat=max(t[t[p].lc].dat,t[t[p].rc].dat);t[p].sum=t[t[p].lc].sum+t[t[p].rc].sum; } void dfs3(int x){ if(!root[c[x]]) root[c[x]]=build();add(root[c[x]],1,n,id[x],w[x]); for(int i=Head[x];i;i=nex[i]){int y=ver[i];if(y==f[x]) continue;dfs3(y);} } int Ssum(int p,int l,int r,int ls,int rs){ if(ls<=l && rs>=r) return t[p].sum; int mid=(l+r)>>1,val=0; if(ls<=mid) val+=Ssum(t[p].lc,l,mid,ls,rs); if(rs>mid) val+=Ssum(t[p].rc,mid+1,r,ls,rs); return val; } int Qsum(int x,int y){ int rt=c[x],res=0; while(top[x]!=top[y]){ if(d[top[x]]<d[top[y]]) swap(x,y); res+=Ssum(root[rt],1,n,id[top[x]],id[x]);x=f[top[x]]; } if(d[x]<d[y]) swap(x,y);res+=Ssum(root[rt],1,n,id[y],id[x]);return res; } int Smax(int p,int l,int r,int ls,int rs){ if(ls<=l && rs>=r) return t[p].dat; int mid=(l+r)>>1,val=0xcfcfcfcf; if(ls<=mid) val=max(val,Smax(t[p].lc,l,mid,ls,rs)); if(rs>mid) val=max(val,Smax(t[p].rc,mid+1,r,ls,rs)); return val; } int Qmax(int x,int y){ int rt=c[x],res=0; while(top[x]!=top[y]){ if(d[top[x]]<d[top[y]]) swap(x,y); res=max(res,Smax(root[rt],1,n,id[top[x]],id[x]));x=f[top[x]]; } if(d[x]<d[y]) swap(x,y);res=max(res,Smax(root[rt],1,n,id[y],id[x]));return res; } int main(){ n=read();q=read(); for(int i=1;i<=n;i++) w[i]=read(),c[i]=read(); for(int i=1;i<n;i++){int u,v;u=read();v=read();ADD(u,v);ADD(v,u);} dfs(1,0,1);dfs2(1,1);dfs3(1); while(q--){ scanf("%s",op);int x=read(),y=read(); if(op[1]=='C'){add(root[c[x]],1,n,id[x],0);c[x]=y;add(root[c[x]],1,n,id[x],w[x]);} if(op[1]=='W'){w[x]=y;add(root[c[x]],1,n,id[x],w[x]);} if(op[1]=='S'){printf("%d\n",Qsum(x,y));} if(op[1]=='M'){printf("%d\n",Qmax(x,y));} } return 0; }