2019ICPC上海F A Simple Problem On A Tree(树链剖分)

这道题很明显是树链剖分+线段树维护路径上的三次方和。把 $(a+x)^3=a^3+3a^2x+3ax^2+x^3$ 展开后可以发现,线段树需要同时维护一次方和、二次方和与三次方和。

这里的修改涉及两种懒标记:一个是 add,一个是 mul;区间赋值可以表示成 mul=0、add=x。

因此要考虑两种标记的优先级。做法与洛谷的线段树模板2相同:下传时先处理 mul 再处理 add,并且每次乘法都要把已有的 add 标记一起乘上,这样“先加后乘”的情况也能得到正确结果。

#include<bits/stdc++.h>
using namespace std;
using ll = long long;

// Size of every global array declared below.
constexpr int N = 200000;
// All sums stored in the segment tree are reduced modulo this value.
constexpr int mod = 1000000007;

// Adjacency lists stored as chained edge slots: h[v] is the first slot of
// vertex v (-1 ends a chain), ne[i] is the slot after i, e[i] is the endpoint
// stored in slot i, and idx is the next unused slot.
int h[N];
int ne[N];
int e[N];
int idx;

// Per-vertex data filled by dfs/dfs1:
//   son[v]  child of v with the largest subtree (0 when v has no child)
//   pre[v]  position of v in the order dfs1 visits vertices (1-based)
//   id[p]   vertex placed at position p (inverse of pre)
//   sz[v]   size of the subtree rooted at v
//   fa[v]   parent of v
int son[N];
int pre[N];
int id[N];
int sz[N];
int fa[N];

// Number of vertices in the current test case.
int n;

// depth[v] is the depth of v, top[v] the first vertex of the chain that
// contains v, and times the counter dfs1 uses to hand out positions.
int depth[N];
int top[N];
int times;

// Initial value of every vertex, indexed by vertex number.
ll w[N];
// One segment-tree node. The tree is built over positions of the dfs1
// numbering (build reads the value of vertex id[l] for position l).
// Field order matters: build() fills nodes with aggregate initialisers.
struct node{
    int l,r;   // inclusive range of positions covered by this node
    ll mul;    // multiplier not yet handed to the children (1 = nothing pending)
    ll ad;     // addend not yet handed to the children, applied after mul (0 = nothing pending)
    ll sum1;   // sum of the values in [l, r], modulo mod
    ll sum2;   // sum of the squares of the values in [l, r], modulo mod
    ll sum3;   // sum of the cubes of the values in [l, r], modulo mod
}tr[N<<2];
// Inserts the directed edge a -> b at the front of a's adjacency chain.
void add(int a,int b){
    const int slot=idx++;   // claim the next unused edge slot
    e[slot]=b;
    ne[slot]=h[a];          // the previous head becomes the successor
    h[a]=slot;
}
// First pass of the decomposition: for every vertex below u it records the
// parent (fa), the depth and the subtree size (sz), and it chooses son[u] as
// the child with the strictly largest subtree seen so far.
// fa[u] and depth[u] must already be set by the caller.
void dfs(int u){
    sz[u]=1;
    for(int edge=h[u];edge!=-1;edge=ne[edge]){
        const int child=e[edge];
        if(child!=fa[u]){
            fa[child]=u;
            depth[child]=depth[u]+1;
            dfs(child);
            sz[u]+=sz[child];
            // son[u] starts at 0 and sz[0] is 0, so the first child always wins.
            if(sz[son[u]]<sz[child])
                son[u]=child;
        }
    }
}
// Second pass of the decomposition: gives u the next position (pre/id),
// records x as the top of u's chain, continues the same chain through son[u]
// first, and then starts a new chain at every other child.
void dfs1(int u,int x){
    ++times;
    pre[u]=times;
    id[times]=u;
    top[u]=x;

    const int heavy=son[u];
    if(heavy==0)
        return;             // no child recorded: the chain ends here

    // Visiting the heavy child first keeps each chain on consecutive positions.
    dfs1(heavy,x);
    for(int edge=h[u];edge!=-1;edge=ne[edge]){
        const int child=e[edge];
        if(child!=fa[u]&&child!=heavy)
            dfs1(child,child);
    }
}
void pushup(int u){
    tr[u].sum1=(tr[u<<1].sum1+tr[u<<1|1].sum1)%mod;
    tr[u].sum2=(tr[u<<1].sum2+tr[u<<1|1].sum2)%mod;
    tr[u].sum3=(tr[u<<1].sum3+tr[u<<1|1].sum3)%mod;
}
void build(int u,int l,int r){
    if(l==r){
        tr[u]={l,r,1,0,w[id[l]],w[id[l]]*w[id[l]]%mod,w[id[l]]*w[id[l]]%mod*w[id[l]]%mod};
    }
    else{
        tr[u]={l,r,1,0,0,0,0};
        int mid=l+r>>1;
        build(u<<1,l,mid);
        build(u<<1|1,mid+1,r);
        pushup(u);
    }
}
void down(int u,ll x,ll y){
    if(y!=1){
        tr[u].sum3=(tr[u].sum3*y%mod*y%mod*y)%mod;
        tr[u].sum2=(tr[u].sum2*y%mod*y)%mod;
        tr[u].sum1=(tr[u].sum1*y%mod)%mod;
        tr[u].mul=tr[u].mul*y%mod;
        tr[u].ad=tr[u].ad*y%mod;
    }
    if(x!=0){
        tr[u].sum3=(tr[u].sum3+3ll*x*tr[u].sum2+3*x%mod*x%mod*tr[u].sum1+(tr[u].r-tr[u].l+1)*x%mod*x%mod*x)%mod;
        tr[u].sum2=(tr[u].sum2+(tr[u].r-tr[u].l+1)*x%mod*x+2*tr[u].sum1*x)%mod;
        tr[u].sum1=(tr[u].sum1+(tr[u].r-tr[u].l+1)*x%mod)%mod;
        tr[u].ad=(tr[u].ad+x)%mod;
    }
}
void pushdown(int u){
    ll y=tr[u].mul,x=tr[u].ad;
    down(u<<1,x,y);
    down(u<<1|1,x,y);
    tr[u].mul=1;
    tr[u].ad=0;
}
void modify(int u,int l,int r,ll x,int opt){
    if(tr[u].l>=l&&tr[u].r<=r){
        if(opt==1){
            tr[u].sum1=(tr[u].r-tr[u].l+1)*x%mod;
            tr[u].sum2=(tr[u].r-tr[u].l+1)*x%mod*x%mod;
            tr[u].sum3=(tr[u].r-tr[u].l+1)*x%mod*x%mod*x%mod;
            tr[u].mul=0;
            tr[u].ad=x;
        }
        else if(opt==2){
            tr[u].sum3=(tr[u].sum3+3ll*x*tr[u].sum2+3*x%mod*x%mod*tr[u].sum1+(tr[u].r-tr[u].l+1)*x%mod*x%mod*x)%mod;
            tr[u].sum2=(tr[u].sum2+(tr[u].r-tr[u].l+1)*x%mod*x+2*tr[u].sum1*x)%mod;
            tr[u].sum1=(tr[u].sum1+(tr[u].r-tr[u].l+1)*x%mod)%mod;
            tr[u].ad=(tr[u].ad+x)%mod;
        }
        else if(opt==3){
            tr[u].sum3=(tr[u].sum3*x%mod*x%mod*x)%mod;
            tr[u].sum2=(tr[u].sum2*x%mod*x)%mod;
            tr[u].sum1=(tr[u].sum1*x%mod)%mod;
            tr[u].mul=tr[u].mul*x%mod;
            tr[u].ad=(tr[u].ad*x)%mod;
        }
        return ;
    }
    pushdown(u);
    int mid=tr[u].l+tr[u].r>>1;
    if(l<=mid)
        modify(u<<1,l,r,x,opt);
    if(r>mid)
        modify(u<<1|1,l,r,x,opt);
    pushup(u);
}
// Applies update `opt` with operand z to every vertex on the tree path x..y.
// The path is cut into chain pieces, each one a contiguous range of positions.
void change(int x,int y,ll z,int opt){
    // Lift the endpoint whose chain top is deeper until both share a chain.
    for(;top[x]!=top[y];x=fa[top[x]]){
        if(depth[top[x]]<depth[top[y]])
            std::swap(x,y);
        modify(1,pre[top[x]],pre[x],z,opt);
    }
    // Same chain: order the endpoints so that x is the shallower one.
    if(depth[y]<depth[x])
        std::swap(x,y);
    modify(1,pre[x],pre[y],z,opt);
}
// Returns the sum of cubes over positions [l, r], modulo mod, starting from node u.
ll query(int u,int l,int r){
    const node& cur=tr[u];
    if(l<=cur.l&&cur.r<=r)
        return cur.sum3;

    // The children must see the pending tags before they are read.
    pushdown(u);
    const int mid=(cur.l+cur.r)>>1;
    ll leftPart=0;
    ll rightPart=0;
    if(l<=mid)
        leftPart=query(u<<1,l,r)%mod;
    if(mid<r)
        rightPart=query(u<<1|1,l,r);
    return (leftPart+rightPart)%mod;
}
// Returns the sum of cubes of the values on the tree path x..y, modulo mod.
ll qpath(int x,int y){
    ll total=0;
    // Lift the endpoint whose chain top is deeper until both share a chain.
    for(;top[x]!=top[y];x=fa[top[x]]){
        if(depth[top[x]]<depth[top[y]])
            std::swap(x,y);
        total=(total+query(1,pre[top[x]],pre[x]))%mod;
    }
    // Same chain: order the endpoints so that x is the shallower one.
    if(depth[y]<depth[x])
        std::swap(x,y);
    total+=query(1,pre[x],pre[y]);
    return total%mod;
}
int main(){
    //ios::sync_with_stdio(false);
    int cas=0;
    int t;
    cin>>t;
    while(t--){
        idx=0;
        scanf("%d",&n);
        memset(h,-1,sizeof h);
        memset(sz,0,sizeof sz);
        memset(son,0,sizeof son);
        memset(depth,0,sizeof depth);
        memset(id,0,sizeof id);
        memset(fa,0,sizeof fa);
        memset(top,0,sizeof top);
        times=0;
        int i;
        printf("Case #%d: \n",++cas);
        for(i=1;i<n;i++){
            int a,b;
            scanf("%d%d",&a,&b);
            add(a,b);
            add(b,a);
        }
        for(i=1;i<=n;i++)
            scanf("%lld",&w[i]);
        depth[1]=1;
        fa[1]=0;
        dfs(1);
        dfs1(1,1);
        build(1,1,n);
        int q;
        scanf("%d",&q);
        while(q--){
            int opt;
            scanf("%d",&opt);
            ll u,v,w;
            if(opt==1){
                scanf("%lld%lld%lld",&u,&v,&w);
                change(u,v,w,1);
            }
            else if(opt==2){
                scanf("%lld%lld%lld",&u,&v,&w);
                change(u,v,w,2);
            }
            else if(opt==3){
                scanf("%lld%lld%lld",&u,&v,&w);
                change(u,v,w,3);
            }
            else{
                scanf("%lld%lld",&u,&v);
                printf("%lld\n",qpath(u,v)%mod);
            }
        }
    }
    return 0;
}
View Code

 

posted @ 2020-12-02 21:47  朝暮不思  阅读(129)  评论(0编辑  收藏  举报