洛谷 P3384 & BZOJ 1036：树链剖分（按子树大小选重儿子的重链剖分）
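值得注意的是：
按 DFS 序（先走重儿子、再走轻儿子）给结点编号 pos 时，一个点的整棵子树在编号上是连续的一段：点 x 的子树恰好是区间 [pos[x], pos[x] + size(x) − 1]，其中 size(x) 是 x 的子树大小。也就是说，修改或查询子树时，只需要在线段树上对这一段区间做一次区间操作；而重链也是连续的一段，所以路径操作可以拆成 O(log n) 段区间操作。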
/**************************************************************
    Problem: 1036
    User: Ez3real
    Language: C++
    Result: Accepted
    Time:3068 ms
    Memory:8112 kb
****************************************************************/
// NOTE(review): this judge header names BZOJ problem 1036, but the program
// below reads the Luogu P3384 input format ("n q r P" followed by operations
// 1-4) -- confirm which submission the header belongs to.
//
// Luogu P3384 -- heavy-light decomposition on a tree rooted at r:
//   1 x y z : add z to every node on the path x..y
//   2 x y   : print the path sum x..y modulo P
//   3 x z   : add z to every node in the subtree of x
//   4 x     : print the subtree sum of x modulo P
//
// All node values, lazy tags and partial sums are kept reduced modulo P.
// Additions accumulate over up to q operations; left unreduced, the long long
// sums could overflow (undefined behaviour) before the final "% P".
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Deliberately no `using namespace std;`: with it, file-scope names such as
// `next` or `size` become ambiguous with std::next (C++11) / std::size (C++17).

namespace {

typedef long long ll;

// Maps any x into [0, mod). Assumes 0 < mod < 2^62.
ll Reduce(ll x, ll mod) { return (x % mod + mod) % mod; }

// Lazy-propagation segment tree over positions 1..n storing range sums
// reduced modulo `mod`. Assumes mod * n fits in a long long (true for
// mod < 2^31 and n around 1e5).
class ModSumSegTree {
 public:
  ModSumSegTree() : n_(0), mod_(1) {}

  // Builds the tree over init[1..n]; every init value must be in [0, mod).
  void Build(const std::vector<ll>& init, int n, ll mod) {
    n_ = n;
    mod_ = mod;
    sum_.assign(4 * n + 4, 0);
    tag_.assign(4 * n + 4, 0);
    BuildRec(1, 1, n_, init);
  }

  ll Mod() const { return mod_; }

  // Adds k (in [0, mod)) to every position in [l, r]; needs 1 <= l <= r <= n.
  void Add(int l, int r, ll k) { AddRec(1, 1, n_, l, r, k); }

  // Returns the sum of positions [l, r] modulo mod; needs 1 <= l <= r <= n.
  ll Sum(int l, int r) { return SumRec(1, 1, n_, l, r); }

 private:
  void BuildRec(int id, int lo, int hi, const std::vector<ll>& init) {
    if (lo == hi) {
      sum_[id] = init[lo];
      return;
    }
    const int mid = (lo + hi) / 2;
    BuildRec(2 * id, lo, mid, init);
    BuildRec(2 * id + 1, mid + 1, hi, init);
    sum_[id] = (sum_[2 * id] + sum_[2 * id + 1]) % mod_;
  }

  // Adds k to all `len` positions covered by node `id` and records it as a
  // pending tag for that node's children.
  void Apply(int id, int len, ll k) {
    sum_[id] = (sum_[id] + k * len) % mod_;
    tag_[id] = (tag_[id] + k) % mod_;
  }

  // Hands the pending tag of node `id` (covering [lo, hi]) to its children.
  void PushDown(int id, int lo, int hi) {
    if (tag_[id] == 0) return;
    const int mid = (lo + hi) / 2;
    Apply(2 * id, mid - lo + 1, tag_[id]);
    Apply(2 * id + 1, hi - mid, tag_[id]);
    tag_[id] = 0;
  }

  void AddRec(int id, int lo, int hi, int l, int r, ll k) {
    if (l <= lo && hi <= r) {
      Apply(id, hi - lo + 1, k);
      return;
    }
    PushDown(id, lo, hi);
    const int mid = (lo + hi) / 2;
    if (l <= mid) AddRec(2 * id, lo, mid, l, r, k);
    if (r > mid) AddRec(2 * id + 1, mid + 1, hi, l, r, k);
    sum_[id] = (sum_[2 * id] + sum_[2 * id + 1]) % mod_;
  }

  ll SumRec(int id, int lo, int hi, int l, int r) {
    if (l <= lo && hi <= r) return sum_[id];
    PushDown(id, lo, hi);
    const int mid = (lo + hi) / 2;
    ll total = 0;
    if (l <= mid) total += SumRec(2 * id, lo, mid, l, r);
    if (r > mid) total += SumRec(2 * id + 1, mid + 1, hi, l, r);
    return total % mod_;
  }

  int n_;
  ll mod_;
  std::vector<ll> sum_;
  std::vector<ll> tag_;
};

// Heavy-light decomposition of a tree given as 1-based adjacency lists
// (index 0 unused). After Decompose():
//   * pos[] numbers the nodes 1..n in a preorder that visits each node's
//     heavy child first, so every heavy chain is a contiguous range;
//   * the subtree of x is exactly [pos[x], pos[x] + subtree[x] - 1];
//   * top[x] is the highest node of x's chain, and parent[root] == 0.
struct HeavyLight {
  std::vector<int> parent, depth, heavy, top, pos, subtree;

  // Uses explicit stacks instead of recursion so that deep (path-shaped)
  // trees cannot overflow the call stack.
  void Decompose(const std::vector<std::vector<int> >& adj, int root) {
    const int n = static_cast<int>(adj.size()) - 1;
    parent.assign(n + 1, 0);
    depth.assign(n + 1, 0);
    heavy.assign(n + 1, 0);
    top.assign(n + 1, 0);
    pos.assign(n + 1, 0);
    subtree.assign(n + 1, 1);

    // Pass 1: preorder from the root, recording parents and depths.
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> stack(1, root);
    depth[root] = 1;
    while (!stack.empty()) {
      const int u = stack.back();
      stack.pop_back();
      order.push_back(u);
      for (std::size_t i = 0; i < adj[u].size(); ++i) {
        const int v = adj[u][i];
        if (v == parent[u]) continue;
        parent[v] = u;
        depth[v] = depth[u] + 1;
        stack.push_back(v);
      }
    }

    // Pass 2: every node appears after its parent in `order`, so walking it
    // backwards finalises each subtree size before it is added to the parent.
    for (int i = static_cast<int>(order.size()) - 1; i >= 0; --i) {
      const int u = order[i];
      const int p = parent[u];
      if (p == 0) continue;
      subtree[p] += subtree[u];
      if (heavy[p] == 0 || subtree[u] > subtree[heavy[p]]) heavy[p] = u;
    }

    // Pass 3: number the nodes. The heavy child is pushed last, so it is
    // popped next and lands right after its parent on the same chain.
    int timer = 0;
    top[root] = root;
    stack.assign(1, root);
    while (!stack.empty()) {
      const int u = stack.back();
      stack.pop_back();
      pos[u] = ++timer;
      for (std::size_t i = 0; i < adj[u].size(); ++i) {
        const int v = adj[u][i];
        if (v == parent[u] || v == heavy[u]) continue;
        top[v] = v;  // a light child starts a new chain
        stack.push_back(v);
      }
      if (heavy[u] != 0) {
        top[heavy[u]] = top[u];
        stack.push_back(heavy[u]);
      }
    }
  }
};

// Adds k (in [0, mod)) to every node on the tree path u..v.
void PathAdd(const HeavyLight& h, ModSumSegTree& seg, int u, int v, ll k) {
  while (h.top[u] != h.top[v]) {
    // Lift the endpoint whose chain top is deeper, covering its chain piece.
    if (h.depth[h.top[u]] < h.depth[h.top[v]]) std::swap(u, v);
    seg.Add(h.pos[h.top[u]], h.pos[u], k);
    u = h.parent[h.top[u]];
  }
  // Same chain now: the shallower endpoint has the smaller position.
  if (h.depth[u] > h.depth[v]) std::swap(u, v);
  seg.Add(h.pos[u], h.pos[v], k);
}

// Returns the sum over the tree path u..v, modulo seg.Mod().
ll PathSum(const HeavyLight& h, ModSumSegTree& seg, int u, int v) {
  ll total = 0;
  while (h.top[u] != h.top[v]) {
    if (h.depth[h.top[u]] < h.depth[h.top[v]]) std::swap(u, v);
    total = (total + seg.Sum(h.pos[h.top[u]], h.pos[u])) % seg.Mod();
    u = h.parent[h.top[u]];
  }
  if (h.depth[u] > h.depth[v]) std::swap(u, v);
  return (total + seg.Sum(h.pos[u], h.pos[v])) % seg.Mod();
}

// Reads one P3384 instance from stdin and prints the answers to stdout.
void Solve() {
  int n = 0;
  int q = 0;
  int root = 0;
  ll mod = 0;
  if (std::scanf("%d%d%d%lld", &n, &q, &root, &mod) != 4) return;
  if (n <= 0 || mod <= 0 || root < 1 || root > n) return;

  std::vector<ll> value(n + 1, 0);
  for (int i = 1; i <= n; ++i) {
    std::scanf("%lld", &value[i]);
    value[i] = Reduce(value[i], mod);
  }

  std::vector<std::vector<int> > adj(n + 1);
  for (int i = 1; i < n; ++i) {
    int x = 0;
    int y = 0;
    if (std::scanf("%d%d", &x, &y) != 2) return;
    adj[x].push_back(y);
    adj[y].push_back(x);
  }

  HeavyLight hld;
  hld.Decompose(adj, root);

  // The segment tree is indexed by decomposition position, not node id.
  std::vector<ll> by_pos(n + 1, 0);
  for (int v = 1; v <= n; ++v) by_pos[hld.pos[v]] = value[v];
  ModSumSegTree seg;
  seg.Build(by_pos, n, mod);

  for (int i = 0; i < q; ++i) {
    int opt = 0;
    if (std::scanf("%d", &opt) != 1) break;
    if (opt == 1) {
      int x = 0;
      int y = 0;
      ll z = 0;
      if (std::scanf("%d%d%lld", &x, &y, &z) != 3) break;
      PathAdd(hld, seg, x, y, Reduce(z, mod));
    } else if (opt == 2) {
      int x = 0;
      int y = 0;
      if (std::scanf("%d%d", &x, &y) != 2) break;
      std::printf("%lld\n", PathSum(hld, seg, x, y));
    } else if (opt == 3) {
      int x = 0;
      ll z = 0;
      if (std::scanf("%d%lld", &x, &z) != 2) break;
      seg.Add(hld.pos[x], hld.pos[x] + hld.subtree[x] - 1, Reduce(z, mod));
    } else if (opt == 4) {
      int x = 0;
      if (std::scanf("%d", &x) != 1) break;
      std::printf("%lld\n",
                  seg.Sum(hld.pos[x], hld.pos[x] + hld.subtree[x] - 1));
    }
  }
}

}  // namespace

int main() {
  Solve();
  return 0;
}
/*
 * BZOJ 1036 -- heavy-light decomposition on a tree rooted at node 1 with
 * point assignment and path sum / path maximum queries.
 *
 * Input: n, then n-1 edges, then the n node values, then q, then q commands
 * of the form "<word> a b":
 *   word starting with 'C'         : set the value of node a to b
 *   word whose second letter is 'M': print the maximum on the path a..b
 *   any other word                 : print the sum on the path a..b
 * (presumably the words are CHANGE / QMAX / QSUM).
 *
 * Storage is sized from n rather than fixed 30010-element arrays.
 */
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

// Deliberately no `using namespace std;`: with it, file-scope names such as
// `next` or `size` become ambiguous with std::next (C++11) / std::size (C++17).

namespace {

typedef long long ll;

// Segment tree over positions 1..n supporting point assignment and
// range-sum / range-max queries.
class SumMaxSegTree {
 public:
  SumMaxSegTree() : n_(0) {}

  // Builds the tree over init[1..n].
  void Build(const std::vector<ll>& init, int n) {
    n_ = n;
    sum_.assign(4 * n + 4, 0);
    max_.assign(4 * n + 4, 0);
    BuildRec(1, 1, n_, init);
  }

  // Sets position p (1 <= p <= n) to value.
  void Assign(int p, ll value) { AssignRec(1, 1, n_, p, value); }

  // Sum / maximum over positions [l, r]; needs 1 <= l <= r <= n.
  ll Sum(int l, int r) const { return SumRec(1, 1, n_, l, r); }
  ll Max(int l, int r) const { return MaxRec(1, 1, n_, l, r); }

 private:
  // Recomputes node `id` from its two children.
  void Pull(int id) {
    sum_[id] = sum_[2 * id] + sum_[2 * id + 1];
    max_[id] = std::max(max_[2 * id], max_[2 * id + 1]);
  }

  void BuildRec(int id, int lo, int hi, const std::vector<ll>& init) {
    if (lo == hi) {
      sum_[id] = init[lo];
      max_[id] = init[lo];
      return;
    }
    const int mid = (lo + hi) / 2;
    BuildRec(2 * id, lo, mid, init);
    BuildRec(2 * id + 1, mid + 1, hi, init);
    Pull(id);
  }

  void AssignRec(int id, int lo, int hi, int p, ll value) {
    if (lo == hi) {
      sum_[id] = value;
      max_[id] = value;
      return;
    }
    const int mid = (lo + hi) / 2;
    if (p <= mid) {
      AssignRec(2 * id, lo, mid, p, value);
    } else {
      AssignRec(2 * id + 1, mid + 1, hi, p, value);
    }
    Pull(id);
  }

  ll SumRec(int id, int lo, int hi, int l, int r) const {
    if (l <= lo && hi <= r) return sum_[id];
    const int mid = (lo + hi) / 2;
    ll total = 0;
    if (l <= mid) total += SumRec(2 * id, lo, mid, l, r);
    if (r > mid) total += SumRec(2 * id + 1, mid + 1, hi, l, r);
    return total;
  }

  // Only descends into children that intersect [l, r], so no sentinel
  // "minus infinity" value is needed.
  ll MaxRec(int id, int lo, int hi, int l, int r) const {
    if (l <= lo && hi <= r) return max_[id];
    const int mid = (lo + hi) / 2;
    if (r <= mid) return MaxRec(2 * id, lo, mid, l, r);
    if (l > mid) return MaxRec(2 * id + 1, mid + 1, hi, l, r);
    return std::max(MaxRec(2 * id, lo, mid, l, r),
                    MaxRec(2 * id + 1, mid + 1, hi, l, r));
  }

  int n_;
  std::vector<ll> sum_;
  std::vector<ll> max_;
};

// Heavy-light decomposition of a tree given as 1-based adjacency lists
// (index 0 unused). After Decompose():
//   * pos[] numbers the nodes 1..n in a preorder that visits each node's
//     heavy child first, so every heavy chain is a contiguous range;
//   * top[x] is the highest node of x's chain, and parent[root] == 0.
struct HeavyLight {
  std::vector<int> parent, depth, heavy, top, pos, subtree;

  // Uses explicit stacks instead of recursion so that deep (path-shaped)
  // trees cannot overflow the call stack.
  void Decompose(const std::vector<std::vector<int> >& adj, int root) {
    const int n = static_cast<int>(adj.size()) - 1;
    parent.assign(n + 1, 0);
    depth.assign(n + 1, 0);
    heavy.assign(n + 1, 0);
    top.assign(n + 1, 0);
    pos.assign(n + 1, 0);
    subtree.assign(n + 1, 1);

    // Pass 1: preorder from the root, recording parents and depths.
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> stack(1, root);
    depth[root] = 1;
    while (!stack.empty()) {
      const int u = stack.back();
      stack.pop_back();
      order.push_back(u);
      for (std::size_t i = 0; i < adj[u].size(); ++i) {
        const int v = adj[u][i];
        if (v == parent[u]) continue;
        parent[v] = u;
        depth[v] = depth[u] + 1;
        stack.push_back(v);
      }
    }

    // Pass 2: every node appears after its parent in `order`, so walking it
    // backwards finalises each subtree size before it is added to the parent.
    for (int i = static_cast<int>(order.size()) - 1; i >= 0; --i) {
      const int u = order[i];
      const int p = parent[u];
      if (p == 0) continue;
      subtree[p] += subtree[u];
      if (heavy[p] == 0 || subtree[u] > subtree[heavy[p]]) heavy[p] = u;
    }

    // Pass 3: number the nodes. The heavy child is pushed last, so it is
    // popped next and lands right after its parent on the same chain.
    int timer = 0;
    top[root] = root;
    stack.assign(1, root);
    while (!stack.empty()) {
      const int u = stack.back();
      stack.pop_back();
      pos[u] = ++timer;
      for (std::size_t i = 0; i < adj[u].size(); ++i) {
        const int v = adj[u][i];
        if (v == parent[u] || v == heavy[u]) continue;
        top[v] = v;  // a light child starts a new chain
        stack.push_back(v);
      }
      if (heavy[u] != 0) {
        top[heavy[u]] = top[u];
        stack.push_back(heavy[u]);
      }
    }
  }
};

// Returns the sum of node values on the tree path u..v.
ll PathSum(const HeavyLight& h, const SumMaxSegTree& seg, int u, int v) {
  ll total = 0;
  while (h.top[u] != h.top[v]) {
    // Lift the endpoint whose chain top is deeper, covering its chain piece.
    if (h.depth[h.top[u]] < h.depth[h.top[v]]) std::swap(u, v);
    total += seg.Sum(h.pos[h.top[u]], h.pos[u]);
    u = h.parent[h.top[u]];
  }
  // Same chain now: the shallower endpoint has the smaller position.
  if (h.depth[u] > h.depth[v]) std::swap(u, v);
  return total + seg.Sum(h.pos[u], h.pos[v]);
}

// Returns the maximum node value on the tree path u..v.
ll PathMax(const HeavyLight& h, const SumMaxSegTree& seg, int u, int v) {
  ll best = std::numeric_limits<ll>::min();
  while (h.top[u] != h.top[v]) {
    if (h.depth[h.top[u]] < h.depth[h.top[v]]) std::swap(u, v);
    best = std::max(best, seg.Max(h.pos[h.top[u]], h.pos[u]));
    u = h.parent[h.top[u]];
  }
  if (h.depth[u] > h.depth[v]) std::swap(u, v);
  return std::max(best, seg.Max(h.pos[u], h.pos[v]));
}

// Reads one BZOJ 1036 instance from stdin and prints the answers to stdout.
void Solve() {
  int n = 0;
  if (std::scanf("%d", &n) != 1 || n <= 0) return;

  std::vector<std::vector<int> > adj(n + 1);
  for (int i = 1; i < n; ++i) {
    int x = 0;
    int y = 0;
    if (std::scanf("%d%d", &x, &y) != 2) return;
    adj[x].push_back(y);
    adj[y].push_back(x);
  }

  HeavyLight hld;
  hld.Decompose(adj, 1);

  // The segment tree is indexed by decomposition position, not node id.
  std::vector<ll> by_pos(n + 1, 0);
  for (int v = 1; v <= n; ++v) {
    int value = 0;
    std::scanf("%d", &value);
    by_pos[hld.pos[v]] = value;
  }
  SumMaxSegTree seg;
  seg.Build(by_pos, n);

  int q = 0;
  if (std::scanf("%d", &q) != 1) return;
  char cmd[16];
  for (int i = 0; i < q; ++i) {
    int a = 0;
    int b = 0;
    // Width limit keeps an over-long command word inside `cmd`.
    if (std::scanf("%15s%d%d", cmd, &a, &b) != 3) break;
    if (cmd[0] == 'C') {
      seg.Assign(hld.pos[a], b);
    } else if (cmd[1] == 'M') {
      std::printf("%lld\n", PathMax(hld, seg, a, b));
    } else {
      std::printf("%lld\n", PathSum(hld, seg, a, b));
    }
  }
}

}  // namespace

int main() {
  Solve();
  return 0;
}