P3313 [SDOI2014]旅行——树链剖分+线段树(动态开点?)
一棵树,其中的点分类,点有权值,在一条链上找到一类点中的最大值或总和;
树链剖分把树变成链;
把每个宗教单开一个线段树,维护区间总和和最大值;
宗教很多,需要动态开点;
树链剖分:
void dfs1(int x,int fa) { siz[x]=1; father[x]=fa; dep[x]=dep[fa]+1; for(int p=last[x];p;p=pre[p]) { int v=other[p]; if(v==fa) continue; dfs1(v,x); siz[x]+=siz[v]; if(siz[v]>siz[son[x]]) son[x]=v; } } void dfs2(int x,int tp) { id[x]=++cnt; top[x]=tp; if(!son[x]) return ; dfs2(son[x],tp); for(int p=last[x];p;p=pre[p]) { int v=other[p]; if(v==father[x]||v==son[x]) continue; dfs2(v,v); } } dfs1(1,0); dfs2(1,1);
然后我们将每个点扔进所属宗教的线段树里;
设c[i]为i所属宗教,root[i]为线段树的总结点(根节点),注意这里用的节点为树链剖分后的新id
线段树不必记录自己的区间大小,节点是根据当前插入节点的新id决定的,不必将所有节点都开全,因为区间里的节点不都属于此线段树;
void build(int &rt,int l,int r,int w,int pos) { if(!rt) rt=++num; //t[rt].l=l;t[rt].r=r; t[rt].ma=max(t[rt].ma,w); t[rt].sum+=w; if(l==r) return ; int mid=(l+r)>>1; if(pos<=mid) build(t[rt].l,l,mid,w,pos); else build(t[rt].r,mid+1,r,w,pos); } for(int i=1;i<=n;i++) { build(root[c[i]],1,n,w[i],id[i]); }
宗教会变,我们需要删除操作和插入操作;
删除时将他所在线段树中的节点删掉即可,插入即为建树操作;
点权值会变,我们只要把点删去,再将他作为一个新点插入即可;
求和操作:
正规树链剖分操作,将链上区间线段树求和即可,注意调用相关的线段树;
求最大值同上;
int query_tot(int rt,int lb,int rb,int l,int r) { if(r<lb||l>rb) return 0; if(r>=rb&&l<=lb) return t[rt].sum; int mid=(lb+rb)>>1; return query_tot(t[rt].l,lb,mid,l,r)+query_tot(t[rt].r,mid+1,rb,l,r); } int tree_tot(int x,int y,int c) { int ans=0; while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]]) swap(x,y); ans+=query_tot(root[c],1,n,id[top[x]],id[x]); x=father[top[x]]; } if(dep[x]>dep[y]) swap(x,y); ans+=query_tot(root[c],1,n,id[x],id[y]); return ans; } int query_ma(int rt,int lb,int rb,int l,int r) { if(r<lb||l>rb) return 0; if(r>=rb&&l<=lb) return t[rt].ma; int mid=(lb+rb)>>1; return max(query_ma(t[rt].l,lb,mid,l,r),query_ma(t[rt].r,mid+1,rb,l,r)); } int tree_ma(int x,int y,int c) { int ans=0; while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]]) swap(x,y); ans=max(ans,query_ma(root[c],1,n,id[top[x]],id[x])); x=father[top[x]]; } if(dep[x]>dep[y]) swap(x,y); ans=max(ans,query_ma(root[c],1,n,id[x],id[y])); return ans; }
总代码:
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int maxn=1e6+10; int pre[maxn*2],last[maxn],other[maxn*2],l; int w[maxn],c[maxn]; struct node_sec { int l,r,ma,sum; }t[maxn*4]; void add(int x,int y) { l++; pre[l]=last[x]; last[x]=l; other[l]=y; } int n,m; int father[maxn]; int siz[maxn],son[maxn]; int dep[maxn]; void dfs1(int x,int fa) { siz[x]=1; father[x]=fa; dep[x]=dep[fa]+1; for(int p=last[x];p;p=pre[p]) { int v=other[p]; if(v==fa) continue; dfs1(v,x); siz[x]+=siz[v]; if(siz[v]>siz[son[x]]) son[x]=v; } } int cnt,id[maxn],top[maxn]; void dfs2(int x,int tp) { id[x]=++cnt; top[x]=tp; if(!son[x]) return ; dfs2(son[x],tp); for(int p=last[x];p;p=pre[p]) { int v=other[p]; if(v==father[x]||v==son[x]) continue; dfs2(v,v); } } int root[maxn]; int num; void build(int &rt,int l,int r,int w,int pos) { if(!rt) rt=++num; //t[rt].l=l;t[rt].r=r; t[rt].ma=max(t[rt].ma,w); t[rt].sum+=w; if(l==r) return ; int mid=(l+r)>>1; if(pos<=mid) build(t[rt].l,l,mid,w,pos); else build(t[rt].r,mid+1,r,w,pos); } char s[3]; void tree_remove(int &rt,int l,int r,int pos) { if(l==r) { t[rt].ma=0;t[rt].sum=0; return ; } int mid=(l+r)>>1; if(pos<=mid) tree_remove(t[rt].l,l,mid,pos); else tree_remove(t[rt].r,mid+1,r,pos); t[rt].ma=max(t[t[rt].l].ma,t[t[rt].r].ma); t[rt].sum=t[t[rt].l].sum+t[t[rt].r].sum; } int query_tot(int rt,int lb,int rb,int l,int r) { if(r<lb||l>rb) return 0; if(r>=rb&&l<=lb) return t[rt].sum; int mid=(lb+rb)>>1; return query_tot(t[rt].l,lb,mid,l,r)+query_tot(t[rt].r,mid+1,rb,l,r); } int tree_tot(int x,int y,int c) { int ans=0; while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]]) swap(x,y); ans+=query_tot(root[c],1,n,id[top[x]],id[x]); x=father[top[x]]; } if(dep[x]>dep[y]) swap(x,y); ans+=query_tot(root[c],1,n,id[x],id[y]); return ans; } int query_ma(int rt,int lb,int rb,int l,int r) { if(r<lb||l>rb) return 0; if(r>=rb&&l<=lb) return t[rt].ma; int mid=(lb+rb)>>1; return max(query_ma(t[rt].l,lb,mid,l,r),query_ma(t[rt].r,mid+1,rb,l,r)); } int tree_ma(int x,int y,int c) { int ans=0; while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]]) swap(x,y); ans=max(ans,query_ma(root[c],1,n,id[top[x]],id[x])); x=father[top[x]]; } if(dep[x]>dep[y]) swap(x,y); ans=max(ans,query_ma(root[c],1,n,id[x],id[y])); return ans; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) { scanf("%d%d",&w[i],&c[i]); } for(int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); add(x,y);add(y,x); } dfs1(1,0); dfs2(1,1); for(int i=1;i<=n;i++) { build(root[c[i]],1,n,w[i],id[i]); } for(int i=1;i<=m;i++) { int x,y; scanf("%s",s); if(s[1]=='C') { scanf("%d%d",&x,&y); tree_remove(root[c[x]],1,n,id[x]); build(root[y],1,n,w[x],id[x]); c[x]=y; continue; } else if(s[1]=='W') { scanf("%d%d",&x,&y); tree_remove(root[c[x]],1,n,id[x]); build(root[c[x]],1,n,y,id[x]); w[x]=y; continue; } else if(s[1]=='S') { scanf("%d%d",&x,&y); printf("%d\n",tree_tot(x,y,c[x])); } else { scanf("%d%d",&x,&y); printf("%d\n",tree_ma(x,y,c[x])); } } return 0; }