树剖
先附上学习的博客链接
https://www.cnblogs.com/ivanovcraft/p/9019090.html
简单说,如果要你给树上的a到b点全部加上x值,想当然我们可以用树上差分,如果让你求a到b的路径总和我们可以先预处理根到所有点的距离再用lca,就可以求出答案为=disa+disb-2*dis_lca;
但是如果有个题要你完成以上两种操作,你总不能每次都更新点的值后处理出所有点的dis吧,所以我们运用树剖(本质是优化暴力)更新的时候我们将预先处理的树已经处理出dfs序,每个点都会dfs序号,所以更新我们可以用线段树,求区间由于树剖我们可以暴力往上跳(树剖有重儿子,优化了)
洛谷:https://www.luogu.org/problemnew/show/P3384
#include<iostream> #include<cstdio> #define int long long using namespace std; const int maxn=1e5+10; struct edge{ int next,to; }e[maxn*2]; struct node{ int l,r,ls,rs,sum,lazy; }a[maxn*2]; int n,m,r,rt,mod,v[maxn],head[maxn],cnt,f[maxn],d[maxn],son[maxn],size[maxn],top[maxn],id[maxn],rk[maxn]; void add(int x,int y) { e[++cnt].next=head[x]; e[cnt].to=y; head[x]=cnt; } void dfs1(int x) { size[x]=1,d[x]=d[f[x]]+1; for(int v,i=head[x];i;i=e[i].next) if((v=e[i].to)!=f[x]) { f[v]=x,dfs1(v),size[x]+=size[v]; if(size[son[x]]<size[v]) son[x]=v; } } void dfs2(int x,int tp) { top[x]=tp,id[x]=++cnt,rk[cnt]=x; if(son[x]) dfs2(son[x],tp); for(int v,i=head[x];i;i=e[i].next) if((v=e[i].to)!=f[x]&&v!=son[x]) dfs2(v,v); } inline void pushup(int x) { a[x].sum=(a[a[x].ls].sum+a[a[x].rs].sum)%mod; } void build(int l,int r,int x) { if(l==r) { a[x].sum=v[rk[l]],a[x].l=a[x].r=l; return; } int mid=l+r>>1; a[x].ls=cnt++,a[x].rs=cnt++; build(l,mid,a[x].ls),build(mid+1,r,a[x].rs); a[x].l=a[a[x].ls].l,a[x].r=a[a[x].rs].r; pushup(x); } inline int len(int x) { return a[x].r-a[x].l+1; } inline void pushdown(int x) { if(a[x].lazy) { int ls=a[x].ls,rs=a[x].rs,lz=a[x].lazy; (a[ls].lazy+=lz)%=mod,(a[rs].lazy+=lz)%=mod; (a[ls].sum+=lz*len(ls))%=mod,(a[rs].sum+=lz*len(rs))%=mod; a[x].lazy=0; } } void update(int l,int r,int c,int x) { if(a[x].l>=l&&a[x].r<=r) { (a[x].lazy+=c)%=mod,(a[x].sum+=len(x)*c)%=mod; return; } pushdown(x); int mid=a[x].l+a[x].r>>1; if(mid>=l) update(l,r,c,a[x].ls); if(mid<r) update(l,r,c,a[x].rs); pushup(x); } int query(int l,int r,int x) { if(a[x].l>=l&&a[x].r<=r) return a[x].sum; pushdown(x); int mid=a[x].l+a[x].r>>1,tot=0; if(mid>=l) tot+=query(l,r,a[x].ls); if(mid<r) tot+=query(l,r,a[x].rs); return tot%mod; } inline int sum(int x,int y) { int ret=0; while(top[x]!=top[y]) { if(d[top[x]]<d[top[y]]) swap(x,y); (ret+=query(id[top[x]],id[x],rt))%=mod; x=f[top[x]]; } if(id[x]>id[y]) swap(x,y); return (ret+query(id[x],id[y],rt))%mod; } inline void updates(int x,int y,int c) { while(top[x]!=top[y]) { if(d[top[x]]<d[top[y]]) swap(x,y); update(id[top[x]],id[x],c,rt); x=f[top[x]]; } if(id[x]>id[y]) swap(x,y); update(id[x],id[y],c,rt); } signed main() { scanf("%lld%lld%lld%lld",&n,&m,&r,&mod); for(int i=1;i<=n;i++) scanf("%lld",&v[i]); for(int x,y,i=1;i<n;i++) { scanf("%lld%lld",&x,&y); add(x,y),add(y,x); } cnt=0,dfs1(r),dfs2(r,r); cnt=0,build(1,n,rt=cnt++); for(int op,x,y,k,i=1;i<=m;i++) { scanf("%lld",&op); if(op==1) { scanf("%lld%lld%lld",&x,&y,&k); updates(x,y,k); } else if(op==2) { scanf("%lld%lld",&x,&y); printf("%lld\n",sum(x,y)); } else if(op==3) { scanf("%lld%lld",&x,&y); update(id[x],id[x]+size[x]-1,y,rt); } else { scanf("%lld",&x); printf("%lld\n",query(id[x],id[x]+size[x]-1,rt)); } } return 0; }
洛谷:
P2486 [SDOI2011]染色
这题不能简单的区间加减,毕竟涉及到段与段之间是否会颜色相同;
在线段树更新就很简单了,判断这个区间的两头和要合并的那个区间相连的会不会颜色一样,颜色一样那么合并时的答案就应该-1;
然后我们在暴力跳的时候怎么办呢?
我们同样需要判断相连部分,我们让两个点一个设为左边一个设为右边,那么左边的为lx,右边的为ly,左边跳的时候,跳出来的右端点和lx比较,右边的和ly比较,lx,ly记录的是两边各自上一次跳的时候的边界值
#include<bits/stdc++.h> using namespace std; const int maxn=2e5+10; #define int long long struct edge { int nxt,to; }e[maxn*2]; struct QQ{int sum,lcol,rcol;}; struct node{ int l,r,ls,rs,sum,lazy,lcol,rcol; }a[maxn*2]; int n,m,rt,top[maxn],r,v[maxn],cnt,f[maxn],son[maxn],sz[maxn],d[maxn],head[maxn],id[maxn],rk[maxn]; void add(int x,int y) { e[++cnt].nxt=head[x]; e[cnt].to=y; head[x]=cnt; } void pushup(int x) { // a[x].sum=(a[a[x].ls].sum+a[a[x].rs].sum)%mod; a[x].lcol=a[a[x].ls].lcol; a[x].rcol=a[a[x].rs].rcol; a[x].sum=a[a[x].ls].sum+a[a[x].rs].sum; if(a[a[x].ls].rcol==a[a[x].rs].lcol) a[x].sum--; } int cal(int x){return a[x].r-a[x].l+1;} void pushdown(int x) { if(a[x].lazy!=-1) { int ls=a[x].ls,rs=a[x].rs,lz=a[x].lazy; (a[ls].lazy=lz),(a[rs].lazy=lz); (a[ls].sum=1),(a[rs].sum=1); a[ls].lcol=a[ls].rcol=a[rs].lcol=a[rs].rcol=lz; a[x].lazy=-1; } } void build(int l,int r,int x) { a[x].lazy=-1; if(l==r){ a[x].lcol=a[x].rcol=v[rk[l]],a[x].l=a[x].r=l,a[x].sum=1; return ; } int mid=l+r>>1; a[x].ls=cnt++,a[x].rs=cnt++; build(l,mid,a[x].ls),build(mid+1,r,a[x].rs); a[x].l=a[a[x].ls].l,a[x].r=a[a[x].rs].r; pushup(x); } void update2(int l,int r,int c,int x) { if(l<=a[x].l&&a[x].r<=r) { a[x].lazy=c,a[x].lcol=a[x].rcol=c; a[x].sum=1; return ; } pushdown(x); int mid=a[x].l+a[x].r>>1; if(l<=mid) update2(l,r,c,a[x].ls); if(r>mid) update2(l,r,c,a[x].rs); pushup(x); } void update1(int x,int y,int val) { while(top[x]!=top[y]){///不在同一条练 if(d[top[x]]<d[top[y]]) ///深度大的也就是最底下的我们先动 swap(x,y); update2(id[top[x]],id[x],val,rt); x=f[top[x]]; } if(id[x]>id[y]) ///现在在同一条链上,那么就找出谁dfs序小的也就是谁在上面谁在下面 { swap(x,y); } update2(id[x],id[y],val,rt); } void dfs1(int u)///进行深度,子树大小,找出重儿子 { sz[u]=1,d[u]=d[f[u]]+1; for(int v,i=head[u];i;i=e[i].nxt){ if((v=e[i].to)!=f[u]){ f[v]=u,dfs1(v),sz[u]+=sz[v]; if(sz[son[u]]<sz[v]) son[u]=v; } } } void dfs2(int u,int tp) { top[u]=tp,id[u]=++cnt,rk[cnt]=u; if(son[u]) dfs2(son[u],tp); for(int v,i=head[u];i;i=e[i].nxt){ if((v=e[i].to)!=f[u]&&v!=son[u]) dfs2(v,v); } } QQ query(int l,int r,int x) { if(l<=a[x].l&&a[x].r<=r) return QQ{a[x].sum,a[x].lcol,a[x].rcol}; pushdown(x); // puts("YES"); int mid=a[x].l+a[x].r>>1; QQ tp1={0,-1,-1},tp2={0,-1,-1},tp; int tot=0; if(l<=mid) tp1=query(l,r,a[x].ls); if(mid<r) tp2=query(l,r,a[x].rs); tp.sum=tp1.sum+tp2.sum-(tp1.rcol==tp2.lcol); tp.lcol=tp1.lcol==-1?tp2.lcol:tp1.lcol; tp.rcol=tp2.rcol==-1?tp1.rcol:tp2.rcol; return tp; } int sum(int x,int y) { int ret=0,lx=-1,ly=-1,ans=0; QQ tp; while(top[x]!=top[y]){ if(d[top[x]]<d[top[y]]) swap(x,y),swap(lx,ly); // puts("1!!"); (tp=query(id[top[x]],id[x],rt)); // puts("2!!"); // cout<<tp.sum<<" "<<lx<<" "<<ly<<id[top[x]]<<" "<<id[x]<<endl; ret+=tp.sum-(tp.rcol==lx); lx=tp.lcol; x=f[top[x]]; } if(id[x]<id[y]) swap(x,y),swap(lx,ly); tp=query(id[y],id[x],rt); // cout<<tp.sum<<" "<<lx<<" "<<ly<<endl; ret+=tp.sum-(tp.lcol==ly)-(tp.rcol==lx); return ret; } signed main() { scanf("%lld%lld",&n,&m); r=1; for(int i=1;i<=n;i++)scanf("%lld",&v[i]); for(int x,y,i=1;i<n;i++){ scanf("%lld%lld",&x,&y); add(x,y),add(y,x); } cnt=0;dfs1(r),dfs2(r,r); cnt=0;build(1,n,rt=cnt++); for(int x,y,k,i=1;i<=m;i++){ char ch; cin>>ch; if(ch=='C'){ scanf("%lld%lld%lld",&x,&y,&k); update1(x,y,k); } else{ // puts("YES"); scanf("%lld%lld",&x,&y); printf("%lld\n",sum(x,y)); } } }