Luogu 3676 小清新数据结构题

推荐博客: http://www.cnblogs.com/Mychael/p/9257242.html

感觉还挺好玩的

首先考虑以1为根,把每一个点子树的权值和都算出来,记为$val_{i}$,那么在所有操作都没有开始的时候(以$1$为根的)$ans_{1} = \sum_{i= 1}^{n}val_{i}^{2}$

考虑到一个修改的操作只会对修改的点$x$到根($1$)链上的点产生影响,那么一次修改只要修对这条树链上的点增加$v - a_{x}$(假设修改后的值为$v$)就好了。

链剖之后线段树维护一下$val_{i}$,区间修改就很简单。

然后考虑换根:

我们发现当以$x$为根的时候,$x$原来的子树显然不会受到影响,而变化了的是原来的根$1$到$x$的链上的点,不妨设有$k$个结点,换根前(以$1$为根)的每个结点子树$val$值和为$a_{i}$,换根后(以$x$为根)的每个结点子树$val$值和为$b_{i}$

有一条显然的性质:$a_{i + 1} + b_{i} = a_{1} = b_{k}$都等于原来全部结点的$val$值和

那么换根之后的答案  $ans_{x} = ans_{1} - \sum_{i = 1}^{k}a_{i}^{2} + \sum_{i = 1}^{k}b_{i}^{2}$

代入上面的那条性质消掉$b$,发现$ans_{x} = ans_{1} + (k - 1)a_{1}^{2} - 2a_{1}\sum_{i = 2}^{k}a_{i}$

设$s_{i}$表示$i$的子树中所有$val$值和,那么$ans_{x} = ans_{1} + s_{1}((k + 1) s_{1} - 2\sum_{i = 1}^{k}s_{i})$。

容易发现这个$k$即为$dep_{x}$,而这个$\sum_{i = 1}^{k}s_{i}$ 和 $s_{1}$显然可以用线段树维护出来

考虑一下, 一次修改还会对$ans_{1}$产生影响,$ans_{1} += \sum_{i = 1}^{tot}(val_{i}+ \Delta v)^{2} - \sum_{i = 1}^{tot}val_{i}^{2} = tot\Delta v^{2} + 2\Delta v\sum_{i = 1}^{tot}val_{i}$。

因为每次发生变化的只有一条树链上的点,所以$tot = dep_{x}$,这个原来的$\sum_{i = 1}^{tot}val_{i}$可以在跳轻重链的过程中算出来。

时间复杂度$O(nlog^{2}n)$。

Code:

#include <cstdio>
#include <cstring>
using namespace std;
typedef long long ll;

const int N = 2e5 + 5;

int n, qn, dfsc = 0, dep[N], siz[N], id[N];
int tot = 0, head[N], top[N], fa[N], son[N];
ll a[N], ans = 0LL, nowSum = 0LL, w[N], val[N];

struct Edge {
    int to, nxt;
} e[N << 1];

inline void add(int from, int to) {
    e[++tot].to = to;
    e[tot].nxt = head[from];
    head[from] = tot;
}

template <typename T>
inline void read(T &X) {
    X = 0;
    char ch = 0;
    T op = 1;
    for(; ch > '9'|| ch < '0'; ch = getchar())
        if(ch == '-') op = -1;
    for(; ch >= '0' && ch <= '9'; ch = getchar())
        X = (X << 3) + (X << 1) + ch - 48;
    X *= op;
}

void dfs1(int x, int fat, int depth) {
    siz[x] = 1, fa[x] = fat, dep[x] = depth, val[x] = a[x];
    int maxson = -1;
    for(int i = head[x]; i; i = e[i].nxt) {
        int y = e[i].to;
        if(y == fat) continue;
        dfs1(y, x, depth + 1);
        siz[x] += siz[y], val[x] += val[y]; 
        if(siz[y] > maxson) 
            maxson = siz[y], son[x] = y;
    }
}

void dfs2(int x, int topf) {
    w[id[x] = ++dfsc] = val[x], top[x] = topf;
    if(!son[x]) return;
    dfs2(son[x], topf);
    for(int i = head[x]; i; i = e[i].nxt) {
        int y = e[i].to;
        if(y == fa[x] || y == son[x]) continue;
        dfs2(y, y);
    }
}

namespace SegT {
    ll s[N << 2], tag[N << 2];
    
    #define lc p << 1
    #define rc p << 1 | 1
    #define mid ((l + r) >> 1)
    
    inline void up(int p) {
        if(p) s[p] = s[lc] + s[rc];
    }
    
    inline void down(int p, int l, int r) {
        if(!tag[p]) return;        
        s[lc] += 1LL * (mid - l + 1) * tag[p];
        s[rc] += 1LL * (r - mid) * tag[p];
        tag[lc] += tag[p], tag[rc] += tag[p];
        tag[p] = 0LL;
    }
    
    void build(int p, int l, int r) {
        tag[p] = 0LL;
        if(l == r) {
            s[p] = w[l];
            return;
        }
        
        build(lc, l, mid);
        build(rc, mid + 1, r);
        up(p);
    }
    
    void modify(int p, int l, int r, int x, int y, ll v) {
        if(x <= l && y >= r) {
            s[p] += 1LL * (r - l + 1) * v;
            tag[p] += v;
            return;
        }
        
        down(p, l, r);
        if(x <= mid) modify(lc, l, mid, x, y, v);
        if(y > mid) modify(rc, mid + 1, r, x, y, v);
        up(p);
    }
    
    ll qSum(int p, int l, int r, int x, int y) {
        if(x <= l && y >= r) return s[p];
        down(p, l, r);
        
        ll res = 0LL;
        if(x <= mid) res += qSum(lc, l, mid, x, y);
        if(y > mid) res += qSum(rc, mid + 1, r, x, y);
        return res;
    }
    
} using namespace SegT;

inline void mTree(int x) {
    ll v, sum = 0LL, len = (ll)dep[x]; read(v);
    v -= a[x], a[x] += v;
    for(; x != 0; x = fa[top[x]]) {
        sum += qSum(1, 1, n, id[top[x]], id[x]);
        modify(1, 1, n, id[top[x]], id[x], v);
    } 
    ans += 2LL * v * sum + 1LL * v * v * len;
    nowSum += v;
}

inline ll qTree(int x) {
    ll res = 0LL;
    for(; x != 0; x = fa[top[x]]) 
        res += qSum(1, 1, n, id[top[x]], id[x]);
    return res;
}

inline void solve(int x) {
    ll k = (ll)dep[x], sum = qTree(x);
    printf("%lld\n", ans + nowSum * ((k + 1) * nowSum - 2 * sum));
}

int main() {
    read(n), read(qn);
    for(int x, y, i = 1; i < n; i++) {
        read(x), read(y);
        add(x, y), add(y, x);
    }
    for(int i = 1; i <= n; i++) read(a[i]);
    
    dfs1(1, 0, 1);
    dfs2(1, 1);
    build(1, 1, n);
    
/*    for(int i = 1; i <= n; i++)
        printf("%d ", dep[i]);
    printf("\n");
    for(int i = 1; i <= n; i++)
        printf("%d ", top[i]);
    printf("\n");   
    for(int i = 1; i <= n; i++)
        printf("%d ", w[i]);
    printf("\n");   */
    
    for(int i = 1; i <= n; i++)    {
        nowSum += a[i];
        ans += val[i] * val[i];
    }
//    printf("%lld\n", ans);
    
    for(int op, x; qn--; ) {
        read(op), read(x);
        if(op == 1) mTree(x);
        else solve(x);
    }
    
    return 0;
}
View Code

 

posted @ 2018-08-24 12:01  CzxingcHen  阅读(195)  评论(0编辑  收藏  举报