LeetCode LCP 05. 发 LeetCoin DFS序+带懒惰标记的线段树
题目描述
力扣决定给一个刷题团队发 LeetCoin
作为奖励。同时,为了监控给大家发了多少 LeetCoin
,力扣有时候也会进行查询。
该刷题团队的管理模式可以用一棵树表示:
- 团队只有一个负责人,编号为 1。除了该负责人外,每个人有且仅有一个领导(负责人没有领导);
- 不存在循环管理的情况,如 A 管理 B,B 管理 C,C 管理 A。
力扣想进行的操作有以下三种:
- 给团队的一个成员(也可以是负责人)发一定数量的
LeetCoin
; - 给团队的一个成员(也可以是负责人),以及他/她管理的所有人(即他/她的下属、他/她下属的下属,……),发一定数量的
LeetCoin
; - 查询某一个成员(也可以是负责人),以及他/她管理的所有人被发到的
LeetCoin
之和。
输入
N
表示团队成员的个数(编号为1 ~ N
,负责人为 1);leadership
是大小为(N - 1) * 2
的二维数组,其中每个元素[a, b]
代表b
是a
的下属;operations
是一个长度为Q
的二维数组,代表以时间排序的操作,格式如下:operations[i][0] = 1
: 代表第一种操作,operations[i][1]
代表成员的编号,operations[i][2]
代表LeetCoin
的数量;operations[i][0] = 2
: 代表第二种操作,operations[i][1]
代表成员的编号,operations[i][2]
代表LeetCoin
的数量;operations[i][0] = 3
: 代表第三种操作,operations[i][1]
代表成员的编号;
输出
返回一个数组,数组里是每次查询的返回值(发 LeetCoin
的操作不需要任何返回值)。
由于发的 LeetCoin
很多,请把每次查询的结果模 1e9+7
(1000000007
)。
样例
输入:N = 6, leadership = [[1, 2], [1, 6], [2, 3], [2, 5], [1, 4]],
operations = [[1, 1, 500], [2, 2, 50], [3, 1], [2, 6, 15], [3, 1]]
输出:[650, 665]
解释:团队的管理关系见下图。
第一次查询时,每个成员得到的LeetCoin的数量分别为(按编号顺序):500, 50, 50, 0, 50, 0;
第二次查询时,每个成员得到的LeetCoin的数量分别为(按编号顺序):500, 50, 50, 0, 50, 15.
限制
1 <= N <= 50000
1 <= Q <= 50000
operations[i][0] != 3
时,1 <= operations[i][2] <= 5000
算法
(DFS 序 + 线段树) \(O(n + Qlog n)\)
-
首先,将树结构转化为一维序列结构,将子树更新和子树查询转为区间更新和区间查询,
DFS
序可以解决此问题。具体做法:维护一个时间戳,时间戳从 0 开始。当新递归进入一个结点时,时间戳加 1,且这个结点的left
记录为更新后的时间戳。然后深度优先递归遍历子树,当这个结点要回溯的时候,记录该结点的right
为当前的时间戳。这样之后,每个结点就有了一个left
值和一个right
值。void dfs(int u, int& ts, const vector<vector<int>>& graph, vector<int>& left, vector<int> &right) { left[u] = ++ts; for (int v : graph[x]) dfs(v, ts, graph, left, right); right[u] = ts; }
-
DFS
序以某个结点为根的子树(包括自己)可以用子树根结点的left
和right
值来表示。以样例为例,整个子树的DFS
序为(1, 2, 3, 5, 6, 4)
,结点2
的left
为 2,right
为 4,所以区间[2, 4]
就是以结点2
为根的子树。 -
带懒惰标记的线段树可以实现区间更新与区间查询。
时间复杂度
- 建立 DFS 序的时间复杂度为 \(O(n)\)。
- 线段树的建立需要 \(O(n)\) 的时间复杂度,线段树的单次更新和查询需要 \(O(log n)\) 的时间复杂度。
- 故总时间复杂度为 \(O(n + Qlog n)\)。
空间复杂度
- 线段树和其他数组占用的空间为 \(O(n)\)。
C++ 代码
const int MOD = 1000000007;
const int N = 50010;
class Solution {
public:
struct Node{
int l, r;
int sum, add;
}tr[N << 2];
void dfs(int x, int& ts, const vector<vector<int>>& graph, vector<int>& left, vector<int> &right) {
left[x] = ++ts;
for (int v : graph[x])
dfs(v, ts, graph, left, right);
right[x] = ts;
}
void build(int u, int l, int r){
if(l == r) tr[u] = {l ,r, 0, 0};
else {
tr[u] = {l, r, 0, 0};
int mid = (l + r) >> 1;
build(u<<1, l, mid), build(u<<1|1, mid + 1, r);
pushup(u);
}
}
void pushdown(int u){
if(!tr[u].add) return ;
tr[u<<1].add = (tr[u<<1].add + tr[u].add) % MOD;
tr[u<<1].sum = (tr[u<<1].sum + (tr[u<<1].r - tr[u<<1].l + 1) * tr[u].add % MOD) % MOD;
tr[u<<1|1].add = (tr[u<<1|1].add + tr[u].add) % MOD;
tr[u<<1|1].sum = (tr[u<<1|1].sum + (tr[u<<1|1].r - tr[u<<1|1].l + 1) * tr[u].add % MOD) % MOD;
tr[u].add = 0;
}
void pushup(int u){
tr[u].sum = (tr[u<<1].sum + tr[u<<1|1].sum) % MOD;
}
int query(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
int sum = 0;
if(l <= mid) sum = query(u<<1, l, r);
if(r > mid) sum = (sum + query(u<<1|1, l, r)) % MOD;
return sum;
}
void modify(int u, int l, int r, int add){
if(tr[u].l >= l && tr[u].r <= r) {
tr[u].sum = (tr[u].sum + (tr[u].r - tr[u].l + 1) * add % MOD) % MOD;
tr[u].add = (tr[u].add + add) % MOD;
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if(l <= mid) modify(u<<1, l, r, add);
if(r > mid) modify(u<<1|1, l, r, add);
pushup(u);
}
vector<int> bonus(int n, vector<vector<int>>& leadership, vector<vector<int>>& operations) {
vector<vector<int>> graph(n + 1);
for (const auto &v : leadership)
graph[v[0]].push_back(v[1]);
vector<int> left(n + 1), right(n + 1);
int ts = 0;
dfs(1, ts, graph, left, right);
build(1, 1, n);
vector<int> ans;
for (const auto &op : operations) {
if (op[0] == 1) modify(1, left[op[1]], left[op[1]], op[2]);
else if (op[0] == 2) modify(1, left[op[1]], right[op[1]], op[2]);
else ans.push_back(query(1, left[op[1]], right[op[1]]));
}
return ans;
}
};
指针写法:
#define MOD 1000000007
#define LL long long
struct Node {
int sum, lazy;
Node *l, *r;
Node(): sum(0), lazy(0), l(NULL), r(NULL){}
};
class Solution {
public:
void dfs(int x, int& ts, const vector<vector<int>>& graph,
vector<int>& left, vector<int> &right) {
left[x] = ++ts;
for (int v : graph[x])
dfs(v, ts, graph, left, right);
right[x] = ts;
}
Node* build(int L, int R) {
Node *ret = new Node();
if (L == R)
return ret;
int mid = (L + R) >> 1;
ret -> l = build(L, mid);
ret -> r = build(mid + 1, R);
return ret;
}
void pushdown(Node *cur, int sz) {
cur -> l -> lazy = (cur -> l -> lazy + cur -> lazy) % MOD;
cur -> r -> lazy = (cur -> r -> lazy + cur -> lazy) % MOD;
cur -> l -> sum =
(cur -> l -> sum + (LL)(cur -> lazy) * (sz - (sz >> 1))) % MOD;
cur -> r -> sum =
(cur -> r -> sum + (LL)(cur -> lazy) * (sz >> 1)) % MOD;
cur -> lazy = 0;
}
int query(int L, int R, int l, int r, Node *cur) {
if (L <= l && r <= R)
return cur -> sum;
pushdown(cur, r - l + 1);
int mid = (l + r) >> 1, ret = 0;
if (L <= mid) ret = (ret + query(L, R, l, mid, cur -> l)) % MOD;
if (mid < R) ret = (ret + query(L, R, mid + 1, r, cur -> r)) % MOD;
return ret;
}
void update(int p, int x, int l, int r, Node *cur) {
if (l == r) {
cur -> sum = (cur -> sum + x) % MOD;
return;
}
pushdown(cur, r - l + 1);
int mid = (l + r) >> 1;
if (p <= mid) update(p, x, l, mid, cur -> l);
else update(p, x, mid + 1, r, cur -> r);
cur -> sum = (cur -> l -> sum + cur -> r -> sum) % MOD;
}
void update(int L, int R, int x, int l, int r, Node *cur) {
if (L <= l && r <= R) {
cur -> lazy = (cur -> lazy + x) % MOD;
cur -> sum = (cur -> sum + (LL)(x) * (r - l + 1)) % MOD;
return;
}
pushdown(cur, r - l + 1);
int mid = (l + r) >> 1;
if (L <= mid) update(L, R, x, l, mid, cur -> l);
if (mid < R) update(L, R, x, mid + 1, r, cur -> r);
cur -> sum = (cur -> l -> sum + cur -> r -> sum) % MOD;
}
vector<int> bonus(int n, vector<vector<int>>& leadership,
vector<vector<int>>& operations) {
vector<vector<int>> graph(n + 1);
for (const auto &v : leadership)
graph[v[0]].push_back(v[1]);
vector<int> left(n + 1), right(n + 1);
int ts = 0;
dfs(1, ts, graph, left, right);
Node *root = build(1, n);
vector<int> ans;
for (const auto &op : operations) {
if (op[0] == 1) update(left[op[1]], op[2], 1, n, root);
else if (op[0] == 2) update(left[op[1]], right[op[1]], op[2], 1, n, root);
else ans.push_back(query(left[op[1]], right[op[1]], 1, n, root));
}
return ans;
}
};