【题解】P3676 小清新数据结构题
思路
树剖。
首先设 \(S_u\) 表示以 \(1\) 为根时 \(u\) 子树的点权和,\(S ^ {\prime}_u\) 表示换根后 \(u\) 子树的点权和,\(ans_i\) 表示以 \(i\) 为根时的答案(即所有结点子树点权和的平方之和),所有点权和为 \(T\).
根据换根的性质可知,每次换根时贡献受到影响的结点都在旧根和新根的路径上。设这条路径上依次经过的 \(m\) 个结点为 \(p_1, p_2, \dots, p_m\),其中 \(p_1 = 1\) 是旧根,\(p_m\) 是新根。
每次询问的答案实际上是 \(ans_1 - \sum\limits_{i = 1} ^ m S_{p_i} ^ 2 + \sum\limits_{i = 1} ^ m (S ^ {\prime}_{p_i})^2\)
因为 \(S_{p_1} = S^{\prime}_{p_m} = T\),可以消掉:
\(ans_1 - \sum\limits_{i = 2}^m S_{p_i}^2 + \sum\limits_{i = 1}^{m - 1} (S^{\prime}_{p_i})^2\)
又根据换根的性质有 \(S^{\prime}_{p_i} + S_{p_{i + 1}} = T\),即 \(S^{\prime}_{p_i} = T - S_{p_{i + 1}}\)
代入原式得:
\(ans_1 - \sum\limits_{i = 2}^m S_{p_i} ^ 2 + \sum\limits_{i = 1} ^ {m - 1} (T - S_{p_{i + 1}}) ^ 2\)
也就是:
\(ans_1 - \sum\limits_{i = 2}^m S_{p_i}^2 + \sum\limits_{i = 2}^{m} (T - S_{p_i}) ^ 2\)
展开平方得:
\(ans_1 - \sum\limits_{i = 2}^m S_{p_i}^2 + \sum\limits_{i = 2}^{m} \left(T^2 - 2 T S_{p_i} + S_{p_i}^2\right)\)
平方项 \(S_{p_i}^2\) 消掉得:
\(ans_1 + \sum\limits_{i = 2}^{m} \left(T^2 - 2 T S_{p_i}\right)\)
也就是:
\(ans_1 + (m - 1) T ^ 2 - 2 T \sum\limits_{i = 2} ^ m S_{p_i}\)
把 \(i = 1\) 的项补回求和中:由于 \(S_{p_1} = T\),求和会多减去 \(2T^2\),因此要在 \(T^2\) 的系数上加 \(2\),得:
\(ans_1 + (m + 1) T ^ 2 - 2 T \sum\limits_{i = 1} ^ m S_{p_i}\)
其中 \(m\) 就是新根的深度(根的深度为 \(1\)),\(\sum S_{p_i}\) 是根到新根路径上所有结点的 \(S\) 之和,直接上树剖维护路径和。
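修改点权时,根到该点路径上所有结点的 \(S\) 都增加同一个增量 \(\Delta\),在线段树上做区间加,并顺便用 \((S + \Delta)^2 - S^2 = 2 S \Delta + \Delta^2\) 维护 \(ans_1\) 就行。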
时间复杂度 \(O((n + q) \log^2 n)\)
代码
#include <cstdio>
#include <vector>
using namespace std;
// 64-bit integer type used for every sum and for the answers.
typedef long long ll;
// Capacity of the per-node arrays; nodes are indexed from 1.
const int maxn = 2e5 + 5;
// Capacity of the segment-tree arrays (four entries per position).
const int sgt_sz = maxn << 2;
// n: node count. q: operations left to process. cnt: last position handed out by dfs2.
int n, q, cnt;
// head: not referenced by any code visible in this file.
// fa: parent of a node (0 for the root). son: child with the largest subtree (0 if none).
// top: first node of the heavy chain that contains the node.
int head[maxn], fa[maxn], son[maxn], top[maxn];
// dep: depth, with the root at depth 1. sz: subtree size. pos: position of a node in dfs2 order.
// nd: node stored at a position (inverse of pos). w: current value of each node.
int dep[maxn], sz[maxn], pos[maxn], nd[maxn], w[maxn];
// ans1: sum over all nodes of (subtree sum)^2 with node 1 as the root.
// tsum: sum of all node values.
// ws: subtree sums produced by dfs1; read when the segment tree is built and not updated afterwards.
ll ans1, tsum, ws[maxn];
// Adjacency lists of the tree.
vector<int> g[maxn];
namespace SGT
{
#define ls (k << 1)
#define rs (k << 1 | 1)
ll sum[sgt_sz], lazy[sgt_sz];
void push_up(int k) { sum[k] = sum[ls] + sum[rs]; }
void push_down(int k, int l, int r)
{
if (!lazy[k]) return;
int mid = (l + r) >> 1;
sum[ls] += (mid - l + 1) * lazy[k];
sum[rs] += (r - mid) * lazy[k];
lazy[ls] += lazy[k], lazy[rs] += lazy[k];
lazy[k] = 0ll;
}
void build(int k, int l, int r)
{
if (l == r)
{
sum[k] = ws[nd[l]];
ans1 += sum[k] * sum[k];
return;
}
int mid = (l + r) >> 1;
build(ls, l, mid);
build(rs, mid + 1, r);
push_up(k);
}
void update(int k, int l, int r, int ql, int qr, int w)
{
if ((l >= ql) && (r <= qr))
{
// printf("modify %lld -> ", ans1);
ans1 += (2ll * sum[k] + (r - l + 1) * w) * w;
// printf("%lld\n", w);
sum[k] += (r - l + 1) * w, lazy[k] += w;
return;
}
push_down(k, l, r);
int mid = (l + r) >> 1;
if (ql <= mid) update(ls, l, mid, ql, qr, w);
if (qr > mid) update(rs, mid + 1, r, ql, qr, w);
push_up(k);
}
ll query(int k, int l, int r, int ql, int qr)
{
if ((l >= ql) && (r <= qr)) return sum[k];
push_down(k, l, r);
int mid = (l + r) >> 1; ll res = 0;
if (ql <= mid) res += query(ls, l, mid, ql, qr);
if (qr > mid) res += query(rs, mid + 1, r, ql, qr);
return res;
}
void modify(int u, int w)
{
while (u)
{
update(1, 1, n, pos[top[u]], pos[u], w);
u = fa[top[u]];
}
}
ll qry(int u)
{
ll res = 0;
while (u)
{
res += query(1, 1, n, pos[top[u]], pos[u]);
u = fa[top[u]];
}
return res;
}
}
// Reads the next decimal integer from stdin.
// Skips every character up to the first digit; a '-' seen while skipping makes
// the result negative. Returns 0 if the input ends before any digit is found.
inline int read()
{
    int res = 0, flag = 1;
    // Keep getchar()'s result as int so EOF stays distinguishable from every
    // character; stored in a char, end of input would never stop the skip loop.
    int ch = getchar();
    while ((ch != EOF) && ((ch < '0') || (ch > '9')))
    {
        if (ch == '-') flag = -1;
        ch = getchar();
    }
    while ((ch >= '0') && (ch <= '9'))
    {
        // Subtract '0' first: adding the raw character code would overflow int
        // for values close to INT_MAX.
        res = res * 10 + (ch - '0');
        ch = getchar();
    }
    return res * flag;
}
// Prints x in decimal to stdout, with a leading '-' for negative values and no
// trailing newline.
inline void write(ll x)
{
    // Work on the magnitude as an unsigned value so that the most negative
    // 64-bit value can be negated without overflow.
    unsigned long long mag = static_cast<unsigned long long>(x);
    if (x < 0)
    {
        putchar('-');
        mag = 0ull - mag;
    }
    char digits[20]; // a 64-bit magnitude has at most 20 decimal digits
    int len = 0;
    do
    {
        digits[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag != 0);
    // Digits were produced least significant first; emit them in reverse.
    while (len > 0) putchar(digits[--len]);
}
void dfs1(int u, int f)
{
fa[u] = f, dep[u] = dep[f] + 1, sz[u] = 1;
for (int v : g[u])
{
if (v == f) continue;
dfs1(v, u);
sz[u] += sz[v], ws[u] += ws[v];
if (sz[v] > sz[son[u]]) son[u] = v;
}
}
void dfs2(int u, int t)
{
top[u] = t, pos[u] = ++cnt, nd[cnt] = u;
if (son[u]) dfs2(son[u], t);
for (int v : g[u])
{
if ((v == fa[u]) || (v == son[u])) continue;
dfs2(v, v);
}
}
int main()
{
// freopen("P3676_1.in", "r", stdin);
// freopen("P3676_1.res", "w", stdout);
n = read(), q = read();
for (int i = 1, u, v; i <= n - 1; i++)
{
u = read(), v = read();
g[u].push_back(v), g[v].push_back(u);
}
for (int i = 1; i <= n; i++) w[i] = ws[i] = read(), tsum += w[i];
dfs1(1, 0), dfs2(1, 1), SGT::build(1, 1, n);
while (q--)
{
int opt, x, y;
opt = read();
if (opt == 1)
{
x = read(), y = read();
tsum += (y - w[x]);
SGT::modify(x, y - w[x]);
w[x] = y;
}
else
{
x = read();
ll ans = ans1 + 1ll * (dep[x] + 1) * tsum * tsum;
// printf("debug %lld %lld\n", ans1, SGT::qry(x) * 2ll * tsum);
printf("%lld\n", ans - SGT::qry(x) * 2ll * tsum);
// write(SGT::qry(x)), putchar('\n');
}
}
return 0;
}