poj_3321 线段树/树状数组
题目大意
一个果树(每个节点的分叉数目不固定)上有N个分叉点(包括最末的叶节点),则有N-1条边,将分叉点进行从1到N编号,每个分叉点上均可以结水果。开始的时候,每个分叉点都有一个水果,之后进行一系列操作,分为两种操作:
(1)Q x: 查询分叉点x以及x的子树上的水果的总数;
(2)C x: 更改分叉点x上的结果状态,即原来有水果变为没水果,没水果变为有水果
对于给定的每个Q操作,给出结果。
题目分析
典型的区间操作,但不过是单点更新,区间查询。对于区间操作,可以选用线段树或者树状数组,线段树可以适用于多种类型的区间操作,而树状数组适用于 单点更新,区间查询的操作。
线段树的应用范围比树状数组略广,二者均为O(nlog(n))的复杂度,但是树状数组的常数系数要小于线段树,且编程复杂度较低。
一、线段树解法
对于区间操作,首先需要定出点和区间。本题目中的区间并不很明显& 套用线段树的结构,线段树的叶子节点均为实际存在的节点,而叶子节点的祖先节点是为了加速更改查询操作而设的。因此可以将果树中的N个分支点映射到一个完整的线段树的N个叶节点,然后在这N个叶节点上构建线段树;而其他非叶子节点,均表示某几个连续叶子节点构成的区间;而在对果树上的某个非叶子分叉点及其子树上的水果总数查询的时候,可以将 该分叉点及其子树上所有节点视为一个连续的区间,这个区间并不和线段树中的节点区间对齐,但可以用它们表示---->这样,才查询的时候,就相当于在线段树上进行连续区间的查询操作。
对于更新操作,需要从线段树的根节点,一路向下递归查找到叶子节点,沿途经过的所有节点i,节点i代表的区间上的水果总数就要变化1.
那么,怎么将果树的各个节点构造出一个连续的区间呢? 可以利用dfs在访问每个果树节点时候的顺序来构造。详见代码。
二、树状数组解法
dfs时候,果树的每个节点都有一个开始时间和结束时间,这样N个节点,就形成一个2*N的序列A,对于叶子节点,其开始时间点和结束时间点在序列A中是连续的,形成一个长度为2的连续区间,对于非叶子节点,其开始时间和结束时间在序列A中不连续,但它们中间包含的点连续起来形成一个区间,该区间内的点和该节点的子树中的节点完全对应。
即序列A中由果树节点x的开始时间点索引start和结束时间点索引end,构成的连续的一个区间和果树中节点x及其子树所有节点完全对应。
这样,就可以转化为一个明显的区间,可以更新和查询。
实现(c++)
一、线段树
#define _CRT_SECURE_NO_WARNINGS #include<stdio.h> #include<string.h> #include<vector> using namespace std; #define MAX_FORK_NUM 100005 #define MAX_NODE_NUM 4*MAX_FORK_NUM #define MAX(a, b) a> b? a :b #define MIN(a, b) a< b? a :b vector<vector<int> > gGraph; struct Node{ int beg; int end; int sum_apple; }; Node gNodes[MAX_NODE_NUM]; pair<int, int > gForkToInterval[MAX_FORK_NUM]; //果树的fork分支点对应到线段树的闭区间 int gForkToIntervalTreeLeave[MAX_FORK_NUM]; //果树的fork分支点对应到线段树的叶子节点 bool gForkHasApple[MAX_FORK_NUM]; void InitForkInterval(int fork, int& cur_pos){ gForkToIntervalTreeLeave[fork] = cur_pos; gForkToInterval[fork].first = cur_pos; cur_pos++; for (int i = 0; i < gGraph[fork].size(); i++){ InitForkInterval(gGraph[fork][i], cur_pos); } gForkToInterval[fork].second = cur_pos - 1; } void PushUp(int index){ int left = 2 * index + 1; int right = 2 * index + 2; gNodes[index].sum_apple = gNodes[left].sum_apple + gNodes[right].sum_apple; } void BuildTree(int beg, int end, int index){ gNodes[index].beg = beg; gNodes[index].end = end; if (beg == end){ gNodes[index].sum_apple = 1; return; } int left = 2 * index + 1; int right = 2 * index + 2; int mid = (beg + end) / 2; BuildTree(beg, mid, left); BuildTree(mid + 1, end, right); PushUp(index); } void InitTree(int n){ int u, v; gGraph.assign(n+1, vector<int>()); for (int i = 0; i < n - 1; i++){ scanf("%d %d", &u, &v); gGraph[u].push_back(v); } } int Query(int beg, int end, int index){ if (gNodes[index].beg >= beg && gNodes[index].end <= end){ return gNodes[index].sum_apple; } if (gNodes[index].beg > end || gNodes[index].end < beg){ return 0; } if (beg > end){ return 0; } int mid = (gNodes[index].beg + gNodes[index].end) / 2; int left = 2 * index + 1, right = 2 * index + 2; return Query(beg, MIN(mid, end), left) + Query(MAX(mid + 1, beg), end, right); } void Update(int beg, int end, int index, bool add){ if (gNodes[index].beg == beg && gNodes[index].end == end){ if (add) gNodes[index].sum_apple ++; else gNodes[index].sum_apple --; return; } if (gNodes[index].beg > end || gNodes[index].end < beg){ return; } if (beg > end){ return; } int mid = (gNodes[index].beg + gNodes[index].end) / 2; if (add) //在向下递归的时候就更新 gNodes[index].sum_apple++; else gNodes[index].sum_apple--; int left = 2 * index + 1, right = 2 * index + 2; Update(beg, MIN(mid, end), left, add); Update(MAX(mid + 1, beg), end, right, add); } void debug(int index){ printf("index = %d, beg = %d, end = %d, sum_apple = %d, \n", index, gNodes[index].beg, gNodes[index].end, gNodes[index].sum_apple); if (gNodes[index].beg == gNodes[index].end){ return; } int left = 2 * index + 1; int right = 2 * index + 2; debug(left); debug(right); } int main(){ int N, M; scanf("%d", &N); InitTree(N); int total_inter_tree_leave = 0; InitForkInterval(1, total_inter_tree_leave); memset(gForkHasApple, true, sizeof(gForkHasApple)); BuildTree(0, total_inter_tree_leave - 1, 0); /* debug(0); for (int i = 1; i <= N; i++){ printf("fork %d's interval = [%d, %d], fork's tree leave = %d\n", i, gForkToInterval[i].first, gForkToInterval[i].second ,gForkToIntervalTreeLeave[i]); } */ scanf("%d", &M); char op; int fork; for (int i = 0; i < M; i++){ getchar(); scanf("%c %d", &op, &fork); if (op == 'Q'){ printf("%d\n", Query(gForkToInterval[fork].first, gForkToInterval[fork].second, 0)); } else{ Update(gForkToIntervalTreeLeave[fork], gForkToIntervalTreeLeave[fork], 0, !gForkHasApple[fork]); gForkHasApple[fork] = !gForkHasApple[fork]; } } return 0; }
二、树状数组
#define _CRT_SECURE_NO_WARNINGS #include<stdio.h> #include<string.h> #include<vector> using namespace std; #define MAX_FORK_NUM 100005 #define MAX_NODE_NUM 2*MAX_FORK_NUM #define MAX(a, b) a> b? a :b #define MIN(a, b) a< b? a :b vector<vector<int> > gGraph; //树状数组,适合于 更新时更新区间内的某个点,O(logn);查询时候查询 一个连续区间 O(logn)。而不适用于每次更新一整个区间 //树状数组利用 原始数组a[i]对应的树状数组 C[i] (i 从1 开始计数) //C[i] = a[i - 2^k + 1] + ... + a[i] (i >= 1), 2^k = lowbit(i) = i&(-i) //C[i] = Sum(i) - Sum(i - lowbit(i))!!!! //对于区间求和,有 a[i] + a[i + 1] + ... + a[j] = sum(j) - sum(i-1). 其中 sum(j) = a[1] + a[2] + ... + a[j] //sum(k) = C[N1] + C[N2] + .. C[Nm-1] + C[Nm],(Nm = k, N1 >= 1) 且 Nt-1 = Nt - lowbit(Nt) //log(n)的复杂度 //对于区间更新,当a[i]更新,有且仅有如下几项需要更新: //C[N1], C[N2], ... C[Nm]。 其中N1 = i, Ni+1 = Ni + lowbit(Ni) //log(n)的复杂度 int gLowBit[MAX_NODE_NUM]; int gStartIndex[MAX_FORK_NUM]; int gEndIndex[MAX_FORK_NUM]; bool gA[MAX_NODE_NUM]; int gC[MAX_NODE_NUM]; //初始化原始数组A和A的树状数组C //开始的时候每个fork点都有水果,因此gA均为1; 而C[i] = Sum(i) - Sum(i - lowbit(i))!!!! void InitArray(int n){ for (int i = 1; i <= n; i++){ gA[i] = 1; gC[i] = i - (i - gLowBit[i]);////C[i] = Sum(i) - Sum(i - lowbit(i))!!!! } } //初始化lowbit数组 void InitLowbit(int n){ for (int i = 1; i <= n; i++){ gLowBit[i] = i&(-i); } } //初始化序列,确定每个fork点的开始和结束index,便于在之后进行查找 //用dfs 遍历图即可 void InitSequence(int fork, int& index){ gStartIndex[fork] = index++; for (int i = 0; i < gGraph[fork].size(); i++){ InitSequence(gGraph[fork][i], index); } gEndIndex[fork] = index++; } //建树 void InitTree(int n){ int u, v; gGraph.assign(n + 1, vector<int>()); for (int i = 0; i < n - 1; i++){ scanf("%d %d", &u, &v); gGraph[u].push_back(v); } } //更新 void Update(int p, int n, bool add){ while (p <= n){ if (add){ gC[p] ++; } else{ gC[p] --; } p += gLowBit[p]; } } //查询 int Query(int p, int n){ int sum = 0; while (p >= 1){ sum += gC[p]; p -= gLowBit[p]; } return sum; } int main(){ int N, M; scanf("%d", &N); InitTree(N); int total_node = 2 * N; InitLowbit(total_node); InitArray(total_node); int index = 1; InitSequence(1, index); scanf("%d", &M); char op; int fork; for (int i = 0; i < M; i++){ getchar(); scanf("%c %d", &op, &fork); if (op == 'Q'){ int sum1 = Query(gStartIndex[fork] - 1, total_node), sum2 = Query(gEndIndex[fork], total_node); printf("%d\n", (sum2 - sum1) / 2); } else{ int start = gStartIndex[fork], end = gEndIndex[fork]; Update(start, total_node, !gA[start]); Update(end, total_node, !gA[end]); gA[start] = !gA[start]; gA[end] = !gA[end]; } } return 0; }