[COJ0968]WZJ的数据结构(负三十二)

[COJ0968]WZJ的数据结构(负三十二)

试题描述

给你一棵N个点的无根树,边上均有权值,每个点上有一盏灯,初始均亮着。请你设计一个数据结构,回答M次操作。

1 x:将节点x上的灯拉一次,即亮变灭,灭变亮。

2 x k:询问当前所有亮灯的节点中距离x第k小的距离(注意如果x亮着也算入)。

输入

第一行为一个正整数N。
第二行到第N行每行三个正整数ui,vi,wi。表示一条树边从ui到vi,距离为wi。
第N+1行为一个正整数M。
最后M行每行三个或两个正整数,格式见题面。

输出

对于每个询问操作,输出答案。

输入示例

10
1 2 2
1 3 1
1 4 3
1 5 2
4 6 2
4 7 1
6 8 1
7 9 2
7 10 1
5
2 1 4
1 5
2 1 4
2 1 9
2 1 1

输出示例

2
3
6
0

数据规模及约定

1<=N,M<=50000
1<=x,ui,vi<=N,1<=v,wi<=1000

题解

动态点分治。对于每个节点我们开一个平衡树,每次修改节点 u 时把 u 以及它到根节点的路径上所有节点上的平衡树都更新一下;对于询问我们先二分答案 x,然后查找一下 u 到根节点路径上所有平衡树,看小于等于 x 的值是否小于 k 个。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;

int read() {
	int x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); }
	return x * f;
}

#define maxn 50010
#define maxm 100010
#define maxlog 17

int n, m, head[maxn], nxt[maxm], to[maxm], dist[maxm];

void AddEdge(int a, int b, int c) {
	to[++m] = b; dist[m] = c; nxt[m] = head[a]; head[a] = m;
	swap(a, b);
	to[++m] = b; dist[m] = c; nxt[m] = head[a]; head[a] = m;
	return ;
}

int dep[maxn], mnd[maxlog][maxn<<1], Log[maxn<<1], clo, pos[maxn];
void build(int u, int pa) {
	mnd[0][pos[u] = ++clo] = dep[u];
	for(int e = head[u]; e; e = nxt[e]) if(to[e] != pa)
		dep[to[e]] = dep[u] + dist[e], build(to[e], u), mnd[0][++clo] = dep[u];
	return ;
}
void rmq_init() {
	Log[1] = 0;
	for(int i = 2; i <= clo; i++) Log[i] = Log[i>>1] + 1;
	for(int j = 1; (1 << j) <= clo; j++)
		for(int i = 1; i + (1 << j) - 1 <= clo; i++)
			mnd[j][i] = min(mnd[j-1][i], mnd[j-1][i+(1<<j-1)]);
	return ;
}
int cdist(int a, int b) {
	int ans = dep[a] + dep[b];
	int l = pos[a], r = pos[b]; if(l > r) swap(l, r);
	int t = Log[r-l+1];
	return ans - (min(mnd[t][l], mnd[t][r-(1<<t)+1]) << 1);
}

int rt, size, siz[maxn], f[maxn];
bool vis[maxn];
void getrt(int u, int pa) {
	siz[u] = 1; f[u] = 0;
	for(int e = head[u]; e; e = nxt[e]) if(to[e] != pa && !vis[to[e]]) {
		getrt(to[e], u);
		siz[u] += siz[to[e]];
		f[u] = max(f[u], siz[to[e]]);
	}
	f[u] = max(f[u], size - siz[u]);
	if(f[rt] > f[u]) rt = u;
	return ;
}
int fa[maxn];
void solve(int u) {
	vis[u] = 1;
	for(int e = head[u]; e; e = nxt[e]) if(!vis[to[e]]) {
		f[rt = 0] = size = siz[u]; getrt(to[e], u);
		fa[rt] = u; solve(rt);
	}
	return ;
}

#define maxnode 1600010

struct Node {
	int v, r, siz;
	Node() {}
	Node(int _, int __): v(_), r(__) {}
} ns[maxnode];
int ToT, ch[maxnode][2], Fa[maxnode], rec[maxnode], rcnt;
inline int getnode() {
	if(rcnt) {
		int o = rec[rcnt--];
		ch[o][0] = ch[o][1] = Fa[o] = 0;
		return o;
	}
	return ++ToT;
}
inline void maintain(int o) {
	if(!o) return ;
	ns[o].siz = ns[ch[o][0]].siz + 1 + ns[ch[o][1]].siz;
	return ;
}
inline void rotate(int u) {
	int y = Fa[u], z = Fa[y], l = 0, r = 1;
	if(z) ch[z][ch[z][1]==y] = u;
	if(ch[y][1] == u) swap(l, r);
	Fa[u] = z; Fa[y] = u; Fa[ch[u][r]] = y;
	ch[y][l] = ch[u][r]; ch[u][r] = y;
	maintain(y); maintain(u);
	return ;
}
inline void Insert(int& o, int v) {
	if(!o) {
		ns[o = getnode()] = Node(v, rand());
		return maintain(o);
	}
	bool d = v > ns[o].v;
	Insert(ch[o][d], v); Fa[ch[o][d]] = o;
	if(ns[ch[o][d]].r > ns[o].r) {
		int t = ch[o][d];
		rotate(t); o = t;
	}
	return maintain(o);
}
inline void Del(int& o, int v) {
	if(!o) return ;
	if(ns[o].v == v) {
		if(!ch[o][0] && !ch[o][1]) rec[++rcnt] = o, o = 0;
		else if(!ch[o][0]) {
			int t = ch[o][1]; Fa[t] = Fa[o]; rec[++rcnt] = o; o = t;
		}
		else if(!ch[o][1]) {
			int t = ch[o][0]; Fa[t] = Fa[o]; rec[++rcnt] = o; o = t;
		}
		else {
			bool d = ns[ch[o][1]].r > ns[ch[o][0]].r;
			int t = ch[o][d]; rotate(t); o = t;
			Del(ch[o][d^1], v);
		}
	}
	else {
		bool d = v > ns[o].v;
		Del(ch[o][d], v);
	}
	return maintain(o);
}
inline int query(int o, int x) {
	if(!o) return 0;
	int ls = ch[o][0] ? ns[ch[o][0]].siz : 0;
	if(x < ns[o].v) return query(ch[o][0], x);
	return ls + 1 + query(ch[o][1], x);
}

int Rt[maxn], Rtfa[maxn];
bool lit[maxn];
void update(int s) {
	if(lit[s]) Insert(Rt[s], 0);
	else Del(Rt[s], 0);
	for(int u = s; fa[u]; u = fa[u]) {
		int d = cdist(fa[u], s);
		if(lit[s]) Insert(Rt[fa[u]], d), Insert(Rtfa[u], d);
		else Del(Rt[fa[u]], d), Del(Rtfa[u], d);
	}
	lit[s] ^= 1;
	return ;
}
int ask(int s, int x) {
	int ans = query(Rt[s], x);
	for(int u = s; fa[u]; u = fa[u]) {
		int d = cdist(fa[u], s);
		ans += query(Rt[fa[u]], x - d) - query(Rtfa[u], x - d);
	}
	return ans;
}

int main() {
	n = read();
	int sum = 0;
	for(int i = 1; i < n; i++) {
		int a = read(), b = read(), c = read(); sum += c;
		AddEdge(a, b, c);
	}
	
	build(1, 0); rmq_init();
	f[rt = 0] = size = n; getrt(1, 0);
	solve(rt);
	
	memset(lit, 1, sizeof(lit));
	for(int i = 1; i <= n; i++) update(i);
	int q = read();
	while(q--) {
		int tp = read(), u = read();
		if(tp == 1) update(u);
		if(tp == 2) {
			int k = read();
			int l = 0, r = sum;
			while(l < r) {
				int mid = l + r >> 1;
				if(ask(u, mid) < k) l = mid + 1; else r = mid;
			}
			printf("%d\n", l);
		}
	}
	
	return 0;
}

 

posted @ 2017-03-22 07:33  xjr01  阅读(211)  评论(0编辑  收藏  举报