Loading

【题解】P3676 小清新数据结构题

思路

树剖。

首先设 \(S_u\) 表示 \(u\) 子树点权和的平方和,\(S ^ {\prime}_u\) 表示换根后 \(u\) 子树点权和的平方和,\(ans_i\) 表示以 \(i\) 为根时的答案,所有点权和为 \(T\).

根据换根的性质可知,每次换根时贡献受到影响的结点都在旧根和新根的路径上。设这条长度为 \(m\) 的路径为 \(p\).

每次询问的答案实际上是 \(ans_1 - \sum\limits_{i = 1} ^ m S_{p_i} ^ 2 + \sum\limits_{i = 1} ^ m (S ^ {\prime}_{p_i})^2\)

因为 \(S_{p_1} = S^{\prime}_{p_m} = T\),可以消掉:

\(ans_1 - \sum\limits_{i = 2}^m S_{p_i}^2 + \sum\limits_{i = 1}^{m - 1} (S^{\prime}_{p_i})^2\)

又根据换根的性质有 \(S^{\prime}_{p_i} + S_{p_{i + 1}} = T\),即 \(S^{\prime}_{p_i} = T - S_{p_{i + 1}}\)

代入原式得:

\(ans_1 - \sum\limits_{i = 2}^m S_{p_i} ^ 2 + \sum\limits_{i = 1} ^ {m - 1} (T - S_{p_{i + 1}}) ^ 2\)

也就是:

\(ans_1 - \sum\limits_{i = 2}^m S_{p_i}^2 + \sum\limits_{i = 2}^{m} (T - S_{p_i}) ^ 2\)

展开平方得:

\(ans_1 - \sum\limits_{i = 2}^m S_{p_i}^2 + \sum\limits_{i = 2}^{m} T^2 - 2 T S_{p_i} + S_{p_i}^2\)

平方项 \(S_{p_i}^2\) 消掉得:

\(ans_1 + \sum\limits_{i = 2}^{m} T^2 - 2 T S_{p_i}\)

也就是:

\(ans_1 + (m - 1) T ^ 2 - 2 T \sum\limits_{i = 2} ^ m S_{p_i}\)

重新加入 \(1\) 的贡献得:

\(ans_1 + (m + 1) T ^ 2 - 2 T \sum\limits_{i = 1} ^ m S_{p_i}\)

直接上树剖维护路径和。

修改的时候在线段树上顺便维护一下就行。

时间复杂度 \(O(n \log^2 n)\)

代码

#include <cstdio>
#include <vector>
using namespace std;

typedef long long ll;

const int maxn = 2e5 + 5;
const int sgt_sz = maxn << 2;

int n, q, cnt;
int head[maxn], fa[maxn], son[maxn], top[maxn];
int dep[maxn], sz[maxn], pos[maxn], nd[maxn], w[maxn];
ll ans1, tsum, ws[maxn];
vector<int> g[maxn];

namespace SGT
{
    #define ls (k << 1)
    #define rs (k << 1 | 1)

    ll sum[sgt_sz], lazy[sgt_sz];

    void push_up(int k) { sum[k] = sum[ls] + sum[rs]; }

    void push_down(int k, int l, int r)
    {
        if (!lazy[k]) return;
        int mid = (l + r) >> 1;
        sum[ls] += (mid - l + 1) * lazy[k];
        sum[rs] += (r - mid) * lazy[k];
        lazy[ls] += lazy[k], lazy[rs] += lazy[k];
        lazy[k] = 0ll;
    }

    void build(int k, int l, int r)
    {
        if (l == r)
        {
            sum[k] = ws[nd[l]];
            ans1 += sum[k] * sum[k];
            return;
        }
        int mid = (l + r) >> 1;
        build(ls, l, mid);
        build(rs, mid + 1, r);
        push_up(k);
    }

    void update(int k, int l, int r, int ql, int qr, int w)
    {
        if ((l >= ql) && (r <= qr))
        {
            // printf("modify %lld -> ", ans1);
            ans1 += (2ll * sum[k] + (r - l + 1) * w) * w;
            // printf("%lld\n", w);
            sum[k] += (r - l + 1) * w, lazy[k] += w;
            return;
        }
        push_down(k, l, r);
        int mid = (l + r) >> 1;
        if (ql <= mid) update(ls, l, mid, ql, qr, w);
        if (qr > mid) update(rs, mid + 1, r, ql, qr, w);
        push_up(k);
    }

    ll query(int k, int l, int r, int ql, int qr)
    {
        if ((l >= ql) && (r <= qr)) return sum[k];
        push_down(k, l, r);
        int mid = (l + r) >> 1; ll res = 0;
        if (ql <= mid) res += query(ls, l, mid, ql, qr);
        if (qr > mid) res += query(rs, mid + 1, r, ql, qr);
        return res;
    }

    void modify(int u, int w)
    {
        while (u)
        {
            update(1, 1, n, pos[top[u]], pos[u], w);
            u = fa[top[u]];
        }
    }

    ll qry(int u)
    {
        ll res = 0;
        while (u)
        {
            res += query(1, 1, n, pos[top[u]], pos[u]);
            u = fa[top[u]];
        }
        return res;
    }
}

inline int read()
{
    int res = 0, flag = 1;
    char ch = getchar();
    while ((ch < '0') || (ch > '9'))
    {
        if (ch == '-') flag = -1;
        ch = getchar();
    }
    while ((ch >= '0') && (ch <= '9'))
    {
        res = res * 10 + ch - '0';
        ch = getchar();
    }
    return res * flag;
}

inline void write(ll x)
{
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

void dfs1(int u, int f)
{
    fa[u] = f, dep[u] = dep[f] + 1, sz[u] = 1;
    for (int v : g[u])
    {
        if (v == f) continue;
        dfs1(v, u);
        sz[u] += sz[v], ws[u] += ws[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}

void dfs2(int u, int t)
{
    top[u] = t, pos[u] = ++cnt, nd[cnt] = u;
    if (son[u]) dfs2(son[u], t);
    for (int v : g[u])
    {
        if ((v == fa[u]) || (v == son[u])) continue;
        dfs2(v, v);
    }
}

int main()
{
    // freopen("P3676_1.in", "r", stdin);
    // freopen("P3676_1.res", "w", stdout);
    n = read(), q = read();
    for (int i = 1, u, v; i <= n - 1; i++)
    {
        u = read(), v = read();
        g[u].push_back(v), g[v].push_back(u);
    }
    for (int i = 1; i <= n; i++) w[i] = ws[i] = read(), tsum += w[i];
    dfs1(1, 0), dfs2(1, 1), SGT::build(1, 1, n);
    while (q--)
    {
        int opt, x, y;
        opt = read();
        if (opt == 1)
        {
            x = read(), y = read();
            tsum += (y - w[x]);
            SGT::modify(x, y - w[x]);
            w[x] = y;
        }
        else
        {
            x = read();
            ll ans = ans1 + 1ll * (dep[x] + 1) * tsum * tsum;
            // printf("debug %lld %lld\n", ans1, SGT::qry(x) * 2ll * tsum);
            printf("%lld\n", ans - SGT::qry(x) * 2ll * tsum);
            // write(SGT::qry(x)), putchar('\n');
        }
    }
    return 0;
}
posted @ 2023-01-14 21:41  kymru  阅读(25)  评论(0编辑  收藏  举报