树链剖分

普通版

#include<iostream>
#include<cstdio>
#include<queue>
#include<stack>
#include<cstring>
#include<algorithm>
#define ll long long
#define rint register int
using namespace std;
inline void read(ll &x)
{
    x=0;ll f=1;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-') f=-1; c=getchar();}
    while(c<='9'&&c>='0'){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    x*=f;
}
inline void print(ll x)
{
    if(x<0){x=-x;putchar('-');}
    if(x>9) print(x/10);
    putchar(x%10+'0');
}
const int SIZE=100100;
ll n,m;
ll cnt;
ll wi[SIZE];
struct edge_{
    ll to,last;
}e[SIZE<<1];
int tot,head[SIZE];
inline void add(int from,int to)
{
    tot++;
    e[tot].to=to;
    e[tot].last=head[from];
    head[from]=tot;
}
ll num[SIZE];//以该节点为根的子树的节点个数
ll son[SIZE];//该节点重儿子
ll dad[SIZE];//该节点父亲
ll dep[SIZE];//深度
ll top[SIZE];//节点所在链的大哥
ll seg[SIZE];//seg[i]储存节点i在线段树中下标
ll rev[SIZE];//rev[i]储存线段树第i个位置的节点编号 
inline void ddfs1(ll nw,ll fa)
{
    num[nw]=1;dad[nw]=fa;
    dep[nw]=dep[fa]+1;
    for(ll i=head[nw];i;i=e[i].last)
    {
        if(e[i].to==fa) continue;
        ddfs1(e[i].to,nw);
        num[nw]+=num[e[i].to];
        if(son[nw]==0||num[son[nw]]<num[e[i].to]) son[nw]=e[i].to;
    }
    return;
}
inline void ddfs2(ll nw,ll tp)
{
    top[nw]=tp;
    seg[nw]=++cnt;
    rev[seg[nw]]=nw;
    if(son[nw]==0) return;
    ddfs2(son[nw],tp);
    for(ll i=head[nw];i;i=e[i].last)
    {
        if(e[i].to==dad[nw]||e[i].to==son[nw]) continue;
        ddfs2(e[i].to,e[i].to);
    }
    return;
}
struct segtree{
    ll l,r,sum,mid,tag;
}st[SIZE<<2];
inline void pushup(ll rt) {st[rt].sum=st[rt<<1].sum+st[rt<<1|1].sum;}
inline void pushdown(ll rt)
{
    if(st[rt].tag)
    {
        st[rt<<1].tag+=st[rt].tag;
        st[rt<<1|1].tag+=st[rt].tag;
        st[rt<<1].sum+=st[rt].tag*(st[rt<<1].r-st[rt<<1].l+1);
        st[rt<<1|1].sum+=st[rt].tag*(st[rt<<1|1].r-st[rt<<1|1].l+1);
        st[rt].tag=0;
    }
    return;
}
inline void build(ll l,ll r,ll rt)
{
    st[rt].l=l;
    st[rt].r=r;
    st[rt].mid=(st[rt].l+st[rt].r)>>1;
    if(l==r){st[rt].sum=wi[rev[l]];return;}
    build(st[rt].l,st[rt].mid,rt<<1);
    build(st[rt].mid+1,st[rt].r,rt<<1|1);
    pushup(rt);
} 
inline void update(ll rt,ll pos,ll val)
{
    if(pos==st[rt].l&&st[rt].r==pos)
    {
        st[rt].sum+=val;
        return;
    }
    pushdown(rt);
    if(pos<=st[rt].mid) update(rt<<1,pos,val);
    else update(rt<<1|1,pos,val);
    pushup(rt);
}
inline void updat(ll rt,ll L,ll R,ll val)
{
    if(L<=st[rt].l&&R>=st[rt].r)
    {
        st[rt].tag+=val;
        st[rt].sum+=val*(st[rt].r-st[rt].l+1);
        return;
    }
    pushdown(rt);
    if(L<=st[rt].mid) updat(rt<<1,L,R,val);
    if(R>st[rt].mid) updat(rt<<1|1,L,R,val);
    pushup(rt);
}
inline ll qsum(ll rt,ll L,ll R)
{
    ll ans=0;
    if(L<=st[rt].l&&R>=st[rt].r) return st[rt].sum;
    pushdown(rt);
    if(L<=st[rt].mid) ans+=qsum(rt<<1,L,R);
    if(R>st[rt].mid) ans+=qsum(rt<<1|1,L,R);
    return ans;
}
inline ll qpath(ll x,ll y)
{
    ll fx=top[x],fy=top[y],ans=0;
    while(fx!=fy)
    {
        if(dep[fx]<dep[fy]) swap(fx,fy),swap(x,y);
        ans+=qsum(1,seg[fx],seg[x]);
        x=dad[fx],fx=top[x];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ans+=qsum(1,seg[x],seg[y]);
    return ans;
}
int main()
{
//  freopen("1.in","r",stdin);
//  freopen("my.out","w",stdout);
    read(n);read(m);
    for(ll i=1;i<=n;i++) read(wi[i]);
    for(ll i=2;i<=n;i++)
    {
        ll a,b;
        read(a);read(b);
        add(a,b);add(b,a);
    }
    ddfs1(1,0);ddfs2(1,1);
    build(1,cnt,1);
    for(ll i=1;i<=m;i++)
    {
        ll a,b,c;
        read(c);read(b);
        if(c==3) printf("%lld\n",qpath(1,b));
        else
        {
            read(a);
            if(c==2) updat(1,seg[b],seg[b]+num[b]-1,a);
            else if(c==1) update(1,seg[b],a);
        }
    }
}

要求换根

#include<iostream>
#include<cstdio>
#include<cmath>
#include<queue>
#include<cstring>
#include<algorithm>
#define lson x<<1
#define rson x<<1|1
#define mid ((st[x].l+st[x].r)>>1)
#define ll long long
#define rint register int
#define mp(x,y) make_pair(x,y)
using namespace std;
template<typename xxx>void read(xxx &x)
{
    x=0;int f=1;char c=getchar();
    for(;c<'0'||c>'9';c=getchar()) if(c=='-') f=-1;
    for(;c>='0'&&c<='9';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
    x*=f;
}
template<typename xxx>void print(xxx x)
{
    if(x<0){putchar('-');x=-x;}
    if(x>9) print(x/10);
    putchar(x%10+'0');
}
//const int mod=1e9+7;
const int maxn=200080;
const int inf=0x7fffffff;
struct edge{
    int to,last;
}e[maxn<<1];
int tot,head[maxn];
inline void add(int from,int to)
{
    ++tot;
    e[tot].to=to;
    e[tot].last=head[from];
    head[from]=tot;
} 
int n,m,t,rt;
int w[maxn],cnt;
int siz[maxn];
int son[maxn];
int dad[maxn];
int dep[maxn];
int top[maxn];
int seg[maxn];
int rev[maxn];
inline void ddfs1(int cur,int fa)
{
    siz[cur]=1;dad[cur]=fa;
    dep[cur]=dep[fa]+1;
    for(rint i=head[cur];i;i=e[i].last)
    {
        if(e[i].to==fa) continue;
        ddfs1(e[i].to,cur);
        siz[cur]+=siz[e[i].to];
        if(son[cur]==0 || siz[son[cur]]<siz[e[i].to]) son[cur]=e[i].to;
    }
    return ;
}
inline void ddfs2(int cur,int tp)
{
    top[cur]=tp;
    seg[cur]=++cnt;//原树上节点i在线段树中编号 
    rev[cnt]=cur;//
    if(son[cur]==0) return ;
    ddfs2(son[cur],tp);
    for(rint i=head[cur];i;i=e[i].last)
    {
        if(e[i].to==dad[cur] || e[i].to==son[cur]) continue;
        ddfs2(e[i].to,e[i].to);
    }
    return ;
}
struct segtree{
    int l,r,tag;ll sum;
}st[maxn<<2];
inline void pushup(int x)
{
    st[x].sum=st[lson].sum+st[rson].sum;
}
inline void pushdown(int x)
{
    if(st[x].tag)
    {
        st[lson].tag+=st[x].tag;        
        st[rson].tag+=st[x].tag;        
        st[lson].sum+=st[x].tag*(st[lson].r-st[lson].l+1);
        st[rson].sum+=st[x].tag*(st[rson].r-st[rson].l+1);
        st[x].tag=0;
    }
}
inline void build(int x,int l,int r)
{
    st[x].l=l;st[x].r=r;st[x].sum=st[x].tag=0;
    if(l==r) 
    {
        st[x].sum=w[rev[l]];
        return ;
    }
    build(lson,l,mid);
    build(rson,mid+1,r);
    pushup(x);
}
inline void change(int x,int l,int r,ll val)
{
    if(l<=st[x].l && st[x].r<=r)
    {
        st[x].sum+=val*(st[x].r-st[x].l+1);
        st[x].tag+=val;
        return ;
    }
    pushdown(x);
    if(l<=mid) change(lson,l,r,val);
    if(r>mid) change(rson,l,r,val);
    pushup(x);
}
inline ll query(int x,int l,int r)
{
    if(l<=st[x].l && st[x].r<=r) return st[x].sum;
    pushdown(x);ll ans=0;
    if(l<=mid) ans+=query(lson,l,r); 
    if(r>mid) ans+=query(rson,l,r);
    return ans; 
}
inline void q1(int x,int y,ll val)
{
    while(top[x]^top[y])
    {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        change(1,seg[top[x]],seg[x],val);
        x=dad[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    change(1,seg[x],seg[y],val);
}
inline ll q2(int x,int y)
{
    ll ans=0;
    while(top[x]^top[y])
    {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ans+=query(1,seg[top[x]],seg[x]);
        x=dad[top[x]]; 
    }
    if(dep[x]>dep[y]) swap(x,y);
    ans+=query(1,seg[x],seg[y]);
    return ans;
}
inline int lca(int x,int y)
{
    while(dad[top[x]]^y && top[x]^top[y]) x=dad[top[x]];
    return dad[top[x]]==y?top[x]:son[y];
}
int main()
{
    read(n);rt=1;
    for(rint i=1;i<=n;++i) read(w[i]);
    for(rint i=2;i<=n;++i) 
    {
        int tp;read(tp);
        add(i,tp);add(tp,i);
    }
    ddfs1(1,0);
    ddfs2(1,1);
    build(1,1,n);
    read(m);
    for(rint i=1;i<=m;++i)
    {
        ll opt,u,v,k;
        read(opt);read(u);
        if(opt==1)
        {
            rt=u;
        }
        else if(opt==2)
        {
            read(v);read(k);
            q1(u,v,k);
        }
        else if(opt==3)
        {
            read(k);
            if(u==rt) change(1,1,n,k);
            else if(seg[rt]<seg[u] || seg[rt]>=seg[u]+siz[u]) change(1,seg[u],seg[u]+siz[u]-1,k);       
            else
            {
                change(1,1,n,k);
                v=lca(rt,u);
                change(1,seg[v],seg[v]+siz[v]-1,-k);
            }
        }
        else if(opt==4)
        {
            read(v);
            print(q2(u,v));
            putchar('\n');
        }
        else if(opt==5)
        {
            if(u==rt) print(query(1,1,n));
            else if(seg[rt]<seg[u] || seg[rt]>=seg[u]+siz[u]) print(query(1,seg[u],seg[u]+siz[u]-1));
            else
            {
                ll ans=query(1,1,n);
                v=lca(rt,u);
                ans-=query(1,seg[v],seg[v]+siz[v]-1);
                print(ans);
            }
            putchar('\n');
        }
    }
    return 0;
}
/*
请根据数据体会
8 7
1 2
1 3
1 4
2 5
5 6
5 7
4 8
根换为5时求1的子树和
v求得为2
再用全部减去2的子树和即为答案 
*/ 
posted @ 2019-11-14 21:11  Thomastine  阅读(123)  评论(0编辑  收藏  举报