【bzoj1095】[ZJOI2007]Hide 捉迷藏 动态点分治+堆
题目描述
捉迷藏 Jiajia和Wind是一对恩爱的夫妻,并且他们有很多孩子。某天,Jiajia、Wind和孩子们决定在家里玩捉迷藏游戏。他们的家很大且构造很奇特,由N个屋子和N-1条双向走廊组成,这N-1条走廊的分布使得任意两个屋子都互相可达。游戏是这样进行的,孩子们负责躲藏,Jiajia负责找,而Wind负责操纵这N个屋子的灯。在起初的时候,所有的灯都没有被打开。每一次,孩子们只会躲藏在没有开灯的房间中,但是为了增加刺激性,孩子们会要求打开某个房间的电灯或者关闭某个房间的电灯。为了评估某一次游戏的复杂性,Jiajia希望知道可能的最远的两个孩子的距离(即最远的两个关灯房间的距离)。 我们将以如下形式定义每一种操作: C(hange) i 改变第i个房间的照明状态,若原来打开,则关闭;若原来关闭,则打开。 G(ame) 开始一次游戏,查询最远的两个关灯房间的距离。
输入
第一行包含一个整数N,表示房间的个数,房间将被编号为1,2,3…N的整数。接下来N-1行每行两个整数a, b,表示房间a与房间b之间有一条走廊相连。接下来一行包含一个整数Q,表示操作次数。接着Q行,每行一个操作,如上文所示。
输出
对于每一个操作Game,输出一个非负整数到hide.out,表示最远的两个关灯房间的距离。若只有一个房间是关着灯的,输出0;若所有房间的灯都开着,输出-1。
样例输入
8
1 2
2 3
3 4
3 5
3 6
6 7
6 8
7
G
C 1
G
C 2
G
C 1
G
样例输出
4
3
3
4
题解
动态点分治 +堆
动态点分治:将点分治的上一层重心与下一层连边,可以得到一棵新树(点分树)。由于每次都是找重心,所以树高不超过$\log$,就可以使用各种数据结构维护各种子树信息。
考虑如果本题是静态的,只有一次查询该怎么做:求出以每个点为根的最长路径,即求 $|$所有节点的 $|$子节点的 $|$子树中的节点到父亲节点的最大值$|$ 的最大值和次大值的和$|$ 的最大值$|$。($|$为断句方法= =)
形象一点,求每个点的子树中的所有节点到父亲节点的距离的最大值$p1$;每个点求出它所有子节点的$p1$以及当前节点状态(存在则为0)中的最大的和次大的,加起来得到$p2$;所有节点的$p2$的最大值就是$p3$。
考虑带修改,多次查询:首先由于有修改,所以树高必须要有保证,所以选择动态树分治的点分树结构。
那么需要是用数据结构,支持查询最大值和次大值,使用3种堆:
$s1[]$:维护一个子树中所有节点到当前点的父亲节点的距离;
$s2[]$:维护一个点的所有子分治节点(点分树中子节点)的$s1$中的最大值,如果当前节点可用,则需要再增加一个$0$;
$s3$:维护所有节点的$s2$的最大值与次大值(如果存在)之和。
每次$s3$的最大值就是答案。
于是就可以自底向上修改路径上的$s1$和$s2$,并修改$s3$。具体实现较为复杂:需要消除下一级对上一级的影响,所以要先删除上一级,再插入上一级;需要实现可以删除的堆,于是需要维护两个堆,删除时将要删的数加入到辅助堆中,每次取堆顶时如果两堆堆顶相同则都弹出。
并且需要维护欧拉遍历序并使用RMQLCA支持$O(1)$查询LCA以保证时间复杂度。
总时间复杂度为$O(n\log^2n)$,空间复杂度为$O(n\log n)$。
#include <queue> #include <cstdio> #define N 100010 using namespace std; struct heap { priority_queue<int> A , B; void push(int x) {A.push(x);} void del(int x) {B.push(x);} int top() { while(!B.empty() && A.top() == B.top()) A.pop() , B.pop(); return A.top(); } int sum() { int a = top(); A.pop(); int b = top(); push(a); return a + b; } int size() {return A.size() - B.size();} }s1[N] , s2[N] , s3; int head[N] , to[N << 1] , next[N << 1] , cnt , vis[N] , deep[N] , pos[N] , md[20][N << 1]; int si[N] , mx[N] , sum , root , log[N << 1] , tot , fa[N] , val[N]; char str[5]; void insert(heap &s) {if(s.size() >= 2) s3.push(s.sum());} void erase(heap &s) {if(s.size() >= 2) s3.del(s.sum());} void add(int x , int y) { to[++cnt] = y , next[cnt] = head[x] , head[x] = cnt; } void dfs(int x , int fa) { int i; md[0][++tot] = deep[x] , pos[x] = tot; for(i = head[x] ; i ; i = next[i]) if(to[i] != fa) deep[to[i]] = deep[x] + 1 , dfs(to[i] , x) , md[0][++tot] = deep[x]; } int lca(int x , int y) { x = pos[x] , y = pos[y]; if(x > y) swap(x , y); int k = log[y - x + 1]; return min(md[k][x] , md[k][y - (1 << k) + 1]); } void getroot(int x , int fa) { int i; si[x] = 1 , mx[x] = 0; for(i = head[x] ; i ; i = next[i]) if(!vis[to[i]] && to[i] != fa) getroot(to[i] , x) , si[x] += si[to[i]] , mx[x] = max(mx[x] , si[to[i]]); mx[x] = max(mx[x] , sum - si[x]); if(mx[x] < mx[root]) root = x; } void solve(int x) { int i; vis[x] = 1; for(i = head[x] ; i ; i = next[i]) if(!vis[to[i]]) sum = si[to[i]] , root = 0 , getroot(to[i] , 0) , fa[root] = x , solve(root); } void join(int x) { erase(s2[x]) , s2[x].push(0) , insert(s2[x]); int t; for(t = x ; fa[t] ; t = fa[t]) { erase(s2[fa[t]]); if(s1[t].size()) s2[fa[t]].del(s1[t].top()); s1[t].push(deep[fa[t]] + deep[x] - 2 * lca(fa[t] , x)) , s2[fa[t]].push(s1[t].top()); insert(s2[fa[t]]); } } void remove(int x) { erase(s2[x]) , s2[x].del(0) , insert(s2[x]); int t; for(t = x ; fa[t] ; t = fa[t]) { erase(s2[fa[t]]); s2[fa[t]].del(s1[t].top()) , s1[t].del(deep[fa[t]] + deep[x] - 2 * lca(fa[t] , x)); if(s1[t].size()) s2[fa[t]].push(s1[t].top()); insert(s2[fa[t]]); } } int main() { int n , m , i , j , x , y , num; scanf("%d" , &n) , num = n; for(i = 1 ; i < n ; i ++ ) scanf("%d%d" , &x , &y) , add(x , y) , add(y , x); dfs(1 , 0); for(i = 2 ; i <= tot ; i ++ ) log[i] = log[i >> 1] + 1; for(i = 1 ; (1 << i) <= tot ; i ++ ) for(j = 1 ; j <= tot - (1 << i) + 1 ; j ++ ) md[i][j] = min(md[i - 1][j] , md[i - 1][j + (1 << (i - 1))]); mx[0] = 1 << 30 , sum = n , getroot(1 , 0) , solve(root); for(i = 1 ; i <= n ; i ++ ) val[i] = 1 , join(i); scanf("%d" , &m); while(m -- ) { scanf("%s" , str); if(str[0] == 'G') { if(num >= 2) printf("%d\n" , s3.top()); else printf("%d\n" , num - 1); } else { scanf("%d" , &x); if(val[x]) num -- , val[x] = 0 , remove(x); else num ++ , val[x] = 1 , join(x); } } return 0; }