2015多校第9场 HDU 5405 Sometimes Naive 树链剖分

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5405

题意: 给你一棵n个节点的树,有点权。

         要求支持两种操作:

    操作1:更改某个节点的权值。

    操作2:给定u,v, 求 Σw[i][j]   i , j 为任意两点且i到j的路径与u到v的路径相交。

 

解法:

  这是一个大树剖题。 

  容易发现对于一个询问,答案为总点权和的平方 减去 去掉u--v这条链后各个子树的点权和的平方的和。 

  开两棵线段树,tag1记录点权和,tag2记录某点的所有轻链子树的点权和的平方的和。

  每次沿着重链往上走时,直接加上这条重链的所有点的tag2和,若有重儿子则直接用tag1计算。由于该条重链必定为其父亲的轻链,故为防止计算重复,还需减去该重链所有点的tag1平方和。

  最后爬到同一颗重链后,还需计算重链上方所有点的贡献。

 

//HDU 5405

//答案为总点权和的平方 减去 去掉u--v这条链后各个子树的点权和的平方的和。
//T1记录点权和,T2记录某点的所有轻链子树的点权和的平方的和
//每次沿着重链往上走时,直接加上这条重链的所有点的tag2和,若有重儿子则直接用tag1计算。
//由于该条重链必定为其父亲的轻链,故为防止计算重复,还需减去该重链所有点的tag1平方和。
//最后爬到同一颗重链后,还需计算重链上方所有点的贡献。

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 1e5+5;
const int mod = 1e9+7;
struct Tree{
    LL sum[maxn<<2];
    void build(){
        memset(sum,0,sizeof(sum));
    }
    void pushup(int rt){
        sum[rt] = (sum[rt<<1]+sum[rt<<1|1])%mod;
    }
    void update(int pos, LL v, int l, int r, int rt){
        if(l == r){
            sum[rt] += v;
            sum[rt] %= mod;
            return;
        }
        int mid = (l+r)>>1;
        if(pos <= mid) update(pos, v, l, mid, rt<<1);
        else update(pos, v, mid+1, r, rt<<1|1);
        pushup(rt);
    }
    LL query(int L, int R, int l, int r, int rt){
        if(L<=l&&r<=R){
            return sum[rt];
        }
        int mid = (l+r)/2;
        if(R<=mid) return query(L,R,l,mid,rt<<1)%mod;
        else if(L>mid) return query(L,R,mid+1,r,rt<<1|1)%mod;
        else return (query(L,mid,l,mid,rt<<1)+query(mid+1,R,mid+1,r,rt<<1|1))%mod;
    }
}T1, T2;
int head[maxn],n, m,  edgecnt, dfsclk;
struct edge{
    int to,next;
}E[maxn*2];
int sz[maxn], top[maxn], son[maxn], dep[maxn];
int fa[maxn], tid[maxn], val[maxn];
void init(){
    edgecnt = 0;
    dfsclk = 0;
    memset(head, -1, sizeof(head));
    memset(son, -1, sizeof(son));
}
void addedge(int u, int v){
    E[edgecnt].to = v, E[edgecnt].next = head[u], head[u] = edgecnt++;
}
void dfs1(int u, int father, int d){
    dep[u] = d;
    fa[u] = father;
    sz[u] = 1;
    for(int i = head[u]; ~i; i=E[i].next){
        int v = E[i].to;
        if(v == father) continue;
        dfs1(v, u, d+1);
        sz[u] += sz[v];
        if(son[u] == -1 || sz[v]>sz[son[u]]) son[u] = v;
    }
}
void dfs2(int u, int tp)
{
    top[u] = tp;
    tid[u] = ++dfsclk;
    if(son[u] == -1) return;
    dfs2(son[u], tp);
    for(int i = head[u]; ~i; i=E[i].next){
        int v = E[i].to;
        if(v!=son[u]&&v!=fa[u])
            dfs2(v,v);
    }
}
inline LL sqr(int x){
    return (LL)x*x;
}
void update(int x, int v){
    int u = top[x];
    while(fa[u]){
        LL sum = T1.query(tid[u], tid[u]+sz[u]-1, 1, n, 1);
        T2.update(tid[fa[u]], ((sqr(val[x]-v)%mod)%mod-(LL)sum*2*(val[x]-v)%mod)%mod, 1, n, 1);
        u = top[fa[u]];
    }
    T1.update(tid[x], v-val[x], 1, n, 1);
    val[x] = v;
}
LL query(int x, int y){
    LL ret = 0;
    while(top[x] != top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ret += T2.query(tid[top[x]], tid[x], 1, n, 1);
        ret %= mod;
        if(son[x]!=-1){
            LL sum = T1.query(tid[son[x]], tid[son[x]]+sz[son[x]]-1, 1, n, 1);
            ret = ret + sum*sum%mod;
            ret %= mod;
        }
        LL sum = T1.query(tid[top[x]], tid[top[x]]+sz[top[x]]-1, 1, n, 1);
        ret = (ret - sum*sum%mod + mod)%mod;
        x = fa[top[x]];
    }
    if(dep[x] > dep[y]) swap(x, y);
    ret += T2.query(tid[x], tid[y], 1, n, 1);
    ret %= mod;
    if(son[y]!=-1){
        LL sum = T1.query(tid[son[y]], tid[son[y]]+sz[son[y]]-1, 1, n, 1);
        ret = (ret + sum*sum%mod)%mod;
    }
    if(fa[x]){
        LL sum = T1.query(1, n, 1, n, 1) - T1.query(tid[x], tid[x]+sz[x]-1, 1, n, 1);
        ret = (ret+sum*sum%mod)%mod;
    }
    return ret;
}
int main()
{
    while(~scanf("%d %d", &n,&m))
    {
        init();
        T1.build();
        T2.build();
        for(int i=1; i<=n; i++) scanf("%d", &val[i]);
        for(int i=1; i<n; i++){
            int u, v;
            scanf("%d %d", &u,&v);
            addedge(u, v);
            addedge(v, u);
        }
        dfs1(1, 0, 0);
        dfs2(1, 1);
        for(int i=1; i<=n; i++){
            int x = val[i];
            val[i] = 0;
            update(i, x);
        }
        while(m--)
        {
            int op, x, y;
            scanf("%d %d %d", &op,&x,&y);
            if(op == 1){
                update(x, y);
            }
            else{
                LL sum = T1.query(tid[1], tid[1]+sz[1]-1, 1, n, 1);
                sum = sum*sum;
                sum = sum-query(x, y);
                sum = sum%mod;
                if(sum<0) sum+=mod;
                printf("%lld\n", sum);
            }
        }
    }
    return 0;
}

 

posted @ 2017-08-09 20:28  zxycoder  阅读(184)  评论(0编辑  收藏  举报