[COJ0968]WZJ的数据结构(负三十二)
[COJ0968]WZJ的数据结构(负三十二)
试题描述
给你一棵N个点的无根树,边上均有权值,每个点上有一盏灯,初始均亮着。请你设计一个数据结构,回答M次操作。
1 x:将节点x上的灯拉一次,即亮变灭,灭变亮。
2 x k:询问当前所有亮灯的节点中距离x第k小的距离(注意如果x亮着也算入)。
输入
第一行为一个正整数N。
第二行到第N行每行三个正整数ui,vi,wi。表示一条树边从ui到vi,距离为wi。
第N+1行为一个正整数M。
最后M行每行三个或两个正整数,格式见题面。
第二行到第N行每行三个正整数ui,vi,wi。表示一条树边从ui到vi,距离为wi。
第N+1行为一个正整数M。
最后M行每行三个或两个正整数,格式见题面。
输出
对于每个询问操作,输出答案。
输入示例
10 1 2 2 1 3 1 1 4 3 1 5 2 4 6 2 4 7 1 6 8 1 7 9 2 7 10 1 5 2 1 4 1 5 2 1 4 2 1 9 2 1 1
输出示例
2 3 6 0
数据规模及约定
1<=N,M<=50000
1<=x,ui,vi<=N,1<=v,wi<=1000
题解
动态点分治。对于每个节点我们开一个平衡树,每次修改节点 u 时把 u 以及它到根节点的路径上所有节点上的平衡树都更新一下;对于询问我们先二分答案 x,然后查找一下 u 到根节点路径上所有平衡树,看小于等于 x 的值是否小于 k 个。
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <cctype> #include <algorithm> using namespace std; int read() { int x = 0, f = 1; char c = getchar(); while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); } while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); } return x * f; } #define maxn 50010 #define maxm 100010 #define maxlog 17 int n, m, head[maxn], nxt[maxm], to[maxm], dist[maxm]; void AddEdge(int a, int b, int c) { to[++m] = b; dist[m] = c; nxt[m] = head[a]; head[a] = m; swap(a, b); to[++m] = b; dist[m] = c; nxt[m] = head[a]; head[a] = m; return ; } int dep[maxn], mnd[maxlog][maxn<<1], Log[maxn<<1], clo, pos[maxn]; void build(int u, int pa) { mnd[0][pos[u] = ++clo] = dep[u]; for(int e = head[u]; e; e = nxt[e]) if(to[e] != pa) dep[to[e]] = dep[u] + dist[e], build(to[e], u), mnd[0][++clo] = dep[u]; return ; } void rmq_init() { Log[1] = 0; for(int i = 2; i <= clo; i++) Log[i] = Log[i>>1] + 1; for(int j = 1; (1 << j) <= clo; j++) for(int i = 1; i + (1 << j) - 1 <= clo; i++) mnd[j][i] = min(mnd[j-1][i], mnd[j-1][i+(1<<j-1)]); return ; } int cdist(int a, int b) { int ans = dep[a] + dep[b]; int l = pos[a], r = pos[b]; if(l > r) swap(l, r); int t = Log[r-l+1]; return ans - (min(mnd[t][l], mnd[t][r-(1<<t)+1]) << 1); } int rt, size, siz[maxn], f[maxn]; bool vis[maxn]; void getrt(int u, int pa) { siz[u] = 1; f[u] = 0; for(int e = head[u]; e; e = nxt[e]) if(to[e] != pa && !vis[to[e]]) { getrt(to[e], u); siz[u] += siz[to[e]]; f[u] = max(f[u], siz[to[e]]); } f[u] = max(f[u], size - siz[u]); if(f[rt] > f[u]) rt = u; return ; } int fa[maxn]; void solve(int u) { vis[u] = 1; for(int e = head[u]; e; e = nxt[e]) if(!vis[to[e]]) { f[rt = 0] = size = siz[u]; getrt(to[e], u); fa[rt] = u; solve(rt); } return ; } #define maxnode 1600010 struct Node { int v, r, siz; Node() {} Node(int _, int __): v(_), r(__) {} } ns[maxnode]; int ToT, ch[maxnode][2], Fa[maxnode], rec[maxnode], rcnt; inline int getnode() { if(rcnt) { int o = rec[rcnt--]; ch[o][0] = ch[o][1] = Fa[o] = 0; return o; } return ++ToT; } inline void maintain(int o) { if(!o) return ; ns[o].siz = ns[ch[o][0]].siz + 1 + ns[ch[o][1]].siz; return ; } inline void rotate(int u) { int y = Fa[u], z = Fa[y], l = 0, r = 1; if(z) ch[z][ch[z][1]==y] = u; if(ch[y][1] == u) swap(l, r); Fa[u] = z; Fa[y] = u; Fa[ch[u][r]] = y; ch[y][l] = ch[u][r]; ch[u][r] = y; maintain(y); maintain(u); return ; } inline void Insert(int& o, int v) { if(!o) { ns[o = getnode()] = Node(v, rand()); return maintain(o); } bool d = v > ns[o].v; Insert(ch[o][d], v); Fa[ch[o][d]] = o; if(ns[ch[o][d]].r > ns[o].r) { int t = ch[o][d]; rotate(t); o = t; } return maintain(o); } inline void Del(int& o, int v) { if(!o) return ; if(ns[o].v == v) { if(!ch[o][0] && !ch[o][1]) rec[++rcnt] = o, o = 0; else if(!ch[o][0]) { int t = ch[o][1]; Fa[t] = Fa[o]; rec[++rcnt] = o; o = t; } else if(!ch[o][1]) { int t = ch[o][0]; Fa[t] = Fa[o]; rec[++rcnt] = o; o = t; } else { bool d = ns[ch[o][1]].r > ns[ch[o][0]].r; int t = ch[o][d]; rotate(t); o = t; Del(ch[o][d^1], v); } } else { bool d = v > ns[o].v; Del(ch[o][d], v); } return maintain(o); } inline int query(int o, int x) { if(!o) return 0; int ls = ch[o][0] ? ns[ch[o][0]].siz : 0; if(x < ns[o].v) return query(ch[o][0], x); return ls + 1 + query(ch[o][1], x); } int Rt[maxn], Rtfa[maxn]; bool lit[maxn]; void update(int s) { if(lit[s]) Insert(Rt[s], 0); else Del(Rt[s], 0); for(int u = s; fa[u]; u = fa[u]) { int d = cdist(fa[u], s); if(lit[s]) Insert(Rt[fa[u]], d), Insert(Rtfa[u], d); else Del(Rt[fa[u]], d), Del(Rtfa[u], d); } lit[s] ^= 1; return ; } int ask(int s, int x) { int ans = query(Rt[s], x); for(int u = s; fa[u]; u = fa[u]) { int d = cdist(fa[u], s); ans += query(Rt[fa[u]], x - d) - query(Rtfa[u], x - d); } return ans; } int main() { n = read(); int sum = 0; for(int i = 1; i < n; i++) { int a = read(), b = read(), c = read(); sum += c; AddEdge(a, b, c); } build(1, 0); rmq_init(); f[rt = 0] = size = n; getrt(1, 0); solve(rt); memset(lit, 1, sizeof(lit)); for(int i = 1; i <= n; i++) update(i); int q = read(); while(q--) { int tp = read(), u = read(); if(tp == 1) update(u); if(tp == 2) { int k = read(); int l = 0, r = sum; while(l < r) { int mid = l + r >> 1; if(ask(u, mid) < k) l = mid + 1; else r = mid; } printf("%d\n", l); } } return 0; }