BZOJ3091 城市旅行

Description

给一颗以 \(1\) 为根的有根树,维护以下操作

  1. 连接 \((u,v)\) 这条边
  2. 删除 \((u,v)\) 这条边
  3. \(u\)\(v\) 的链上每个点加上一个数
  4. 求在 \((u,v)\) 上任意选两个点它们之间的权值和的期望

\(n, m \leq 50000, a_i \leq 10^6\)

Solution

前三个操作就是 LCT 板子,考虑如何在 LCT 上维护 4 操作

为了方便,设这个路径是 \(a_1, a_2, a_3, \cdots, a_{siz}\) ,其中 \(siz\) 是长度

考虑每个点的贡献,易得我们要求的期望值 \(=\frac{\sum\limits_{i=1}^{siz} i (siz - i + 1)a_i}{\frac{siz(siz+1)}{2}}\)

显然这个分母很好搞,只需要考虑怎么在 LCT 上维护分子,或者说在平衡树上。

也就是说,如果知道左子和右子的答案如何更新出这个点的答案

设左子表示 \(a_1, a_2, \cdots, a_p\), 该点的值是 \(a_{p+1}\) ,右子表示 \(a_{p+2}, \cdots, a_{siz}\)

可以得到:左子的 \(siz_0 = p\),右子的 \(siz_1 = siz - p - 1\)

改点要的答案减去左子的答案减去右子的答案便是

\(\sum\limits_{i=1}^{siz}i(siz - i + 1)a_i - \sum\limits_{i=1}^{p}i(p-i+1)a_i-\sum\limits_{i=p+2}^{siz} (i-p-1)(siz - i + 1)a_i\)

\(=\sum\limits_{i=1}^{p} i(siz-p)a_i+a_{p+1}(p+1)(siz-p)+\sum\limits_{i=p+2}^{siz}(p+1)(siz-i+1)a_i\)

根据上面得到的 \(siz_0=p,siz_1=siz-p-1\) 简单化简一下可以得到

\(=(siz_1+1)\sum\limits_{i=1}^{siz_0}i\cdot a_i+a_{siz_0+1}(siz_0+1)(siz_1+1)+(siz_0+1)\sum\limits_{i=p+2}^{siz}(siz - i +1)a_i\)

到这里应该你已经知道怎么做了..

为了清楚,再令

\(b_1, b_2, \cdots,b_{siz_b}\) 是左子的, \(c_1, c_2, \cdots,c_{siz_c}\) 是右子的,\(d\) 是这个点本身的值。那么可以化简成简单清楚对称的形式:

\(=(siz_c+1)\sum\limits_{i=1}^{siz_b}i\cdot b_i+d(siz_b+1)(siz_c+1)+(siz_b+1)\sum\limits_{i=1}^{siz_c}(siz_c-i+1)c_i\)

你只需要每个点再维护两个值:

\(ls=\sum\limits_{i=1}^{siz}i\cdot a_i\)\(rs=\sum\limits_{i=1}^{siz}(siz - i +1)a_i\)

就可以从左右两个儿子得到自己的值

这两个东西维护还是比较简单的..具体的话就是再维护一个 \(s\) 为子树里所有数的和然后令 \(b,c\) 是左右两个儿子,那么有

\(ls = ls_b+d\cdot(siz_b+1)+ls_c+s_c (siz_b+1)\)

\(rs=rs_c+d\cdot(siz_c+1)+rs_b+s_b(siz_c+1)\)

就这样维护

以上是如何用左右儿子的信息得到自己,再来考虑链加的问题

一条链加上一个数 \(x\) ,那么会如何影响我们维护的值?

  • 对于 \(s\)\(s = s + siz\cdot x\)
  • 对于 \(ls\)\(ls = ls + \sum\limits_{i=1}^{siz}i \cdot x = ls + \frac{siz(siz+1)}{2}\cdot x\)
  • 对于 \(rs\):和 ls 一样 \(rs = rs+\frac{siz(siz+1)}{2}\cdot x\)
  • 对于最后的答案 \(S\)\(S = S + \sum\limits_{i=1}^{siz} i \cdot (siz - i +1)\cdot x\) 通过简单计算可得 \(S= S+\frac{siz(siz+1)(siz+2)}{6}\cdot x\)
  • 对于自己的值:直接加上 \(x\) (废话)

然后 LCT 板子套一套就做完了

注意事项:

  • 翻转的时候需要 swap(ls, rs)
  • 两个点之间是联通的时候才执行链加操作(坑死我了)

Code

/**
 * Author: AcFunction
 * Date:   2019-02-17 11:17:08
 * Email:  3486942970@qq.com
**/

#include <bits/stdc++.h>
#define ll long long

using namespace std;

const int N = 200200; 
const ll INF = (ll)1e18; 

int n, m;
ll a[N]; 
struct node {

  int rev; 
  ll d, s, ls, rs, s1, add, siz; 
  node *ch[2], *prt;

  int isr() { 
    return (!prt) || ( prt->ch[0] != (this) && prt->ch[1] != (this) ); 
  }
  
  int dir() { 
    return prt->ch[1] == (this); 
  }
  void setc(node *p, int k) {
    (this)->ch[k] = p; 
    if(p) p->prt = (this); 
  }
  
  void setr() {
    rev ^= 1; 
    swap(ls, rs);
    swap(ch[0], ch[1]); 
  }

  void seta(ll x) {
    d += x, add += x; s += siz * x; 
    ls += siz * (siz + 1) / 2 * x; 
    rs += siz * (siz + 1) / 2 * x; 
    s1 += siz * (siz + 1) * (siz + 2) / 6 * x; 
  }

  void upd() { 
    siz = 1, s = d; 
    if(ch[0]) siz += ch[0]->siz, s += ch[0]->s; 
    if(ch[1]) siz += ch[1]->siz, s += ch[1]->s; 
    if(ch[0] && ch[1]) {
      ls = ch[0]->ls + d * (ch[0]->siz + 1) + ch[1]->ls + ch[1]->s * (ch[0]->siz + 1); 
      rs = ch[1]->rs + d * (ch[1]->siz + 1) + ch[0]->rs + ch[0]->s * (ch[1]->siz + 1);
      s1 = ch[0]->s1 + ch[1]->s1;
      s1 += ch[0]->ls * (ch[1]->siz + 1);
      s1 += ch[1]->rs * (ch[0]->siz + 1); 
      s1 += d * (ch[0]->siz + 1) * (ch[1]->siz + 1);  
    } else if(ch[0]) {
      ls = ch[0]->ls + d * (ch[0]->siz + 1); 
      rs = d + ch[0]->rs + ch[0]->s; 
      s1 = ch[0]->s1 + ch[0]->ls + d * (ch[0]->siz + 1); 
    } else if(ch[1]) {
      ls = d + ch[1]->ls + ch[1]->s; 
      rs = d * (ch[1]->siz + 1) + ch[1]->rs; 
      s1 = ch[1]->s1 + ch[1]->rs + d * (ch[1]->siz + 1); 
    } else {
      ls = rs = s1 = d; 
    }
  }

  void push() {
    if(rev) {
      if(ch[0]) ch[0]->setr();
      if(ch[1]) ch[1]->setr(); 
      rev = 0; 
    }
    if(add) {
      if(ch[0]) ch[0]->seta(add);
      if(ch[1]) ch[1]->seta(add); 
      add = 0; 
    }
  }

} pool[N * 2], *P[N], *cur = pool;

node *New(ll d) { 
  node *p = cur++; 
  p->d = d, p->ls = p->rs = d; 
  p->s = p->s1 = d; 
  p->prt = p->ch[0] = p->ch[1] = 0; 
  p->siz = 1; 
  return p; 
}

void rotate(node *p) {
  node *prt = p->prt; int k = p->dir(); 
  if(!prt->isr()) prt->prt->setc(p, prt->dir()); 
  else p->prt = prt->prt; prt->setc(p->ch[!k], k); 
  p->setc(prt, !k); prt->upd(); p->upd();
}

node *sta[N]; int top; 
void splay(node *p) {
  node *q = p;
  while(1) { 
    sta[++top] = q; 
    if(q->isr()) break ;  
    q = q->prt; 
  } 
  while(top) 
    (sta[top--])->push(); 
  while(!p->isr()) {
    if(p->prt->isr()) rotate(p); 
    else if(p->dir() == p->prt->dir()) rotate(p->prt), rotate(p); 
    else rotate(p), rotate(p); 
  } p->upd(); 
}


node *access(node *p) {
  node *q = 0;
  for(; p; p = p->prt) {
    splay(p); 
    p->ch[1] = q; 
    (q = p)->upd(); 
  } return q; 
}

inline void mkroot(node *p) { access(p); splay(p); p->setr(); p->push(); }
inline void split (node *p, node *q) { mkroot(p); access(q); splay(p); }
inline void link  (node *p, node *q) { mkroot(p); mkroot(q); q->prt = p; }
inline void cut   (node *p, node *q) { split(p, q); p->ch[1] = q->prt = 0; }
inline node *find(node *p) { access(p); splay(p); while(p->ch[0]) p = p->ch[0]; return p; }

inline ll gcd(ll a, ll b) {
  return !b ? a : gcd(b, a % b); 
}

int main() {
  scanf("%d %d", &n, &m);
  for(int i = 1; i <= n; i++) {
    scanf("%lld", &a[i]); 
    P[i] = New(a[i]); 
  }
  for(int i = 1; i < n; i++) {
    int u, v; scanf("%d %d", &u, &v);
    link(P[u], P[v]); 
  } 
  for(int i = 1; i <= m; i++) {
    int op, u, v; ll d;  
    scanf("%d %d %d", &op, &u, &v); 
    if(op == 1) if(find(P[u]) == find(P[v])) cut(P[u], P[v]);
    if(op == 2) if(find(P[u]) != find(P[v])) link(P[u], P[v]); 
    if(op == 3) {
      scanf("%lld", &d); 
      if(find(P[u]) != find(P[v])) continue ; // important!
      split(P[u], P[v]), P[u]->seta(d); 
    }
    if(op == 4) {
      if(find(P[u]) != find(P[v])) {
        printf("-1\n"); 
        continue ; 
      }
      split(P[u], P[v]); 
      ll ans = P[u]->s1; 
      ll t = P[u]->siz * (P[u]->siz + 1) / 2; 
      ll g = gcd(ans, t); 
      printf("%lld/%lld\n", ans / g, t / g); 
    }
  }
  return 0; 
}
posted @ 2019-03-03 14:30  AcFunction  阅读(171)  评论(0编辑  收藏  举报