LeetCode LCP 05. 发 LeetCoin DFS序+带懒惰标记的线段树

题目描述

力扣决定给一个刷题团队发 LeetCoin 作为奖励。同时,为了监控给大家发了多少 LeetCoin,力扣有时候也会进行查询。

该刷题团队的管理模式可以用一棵树表示:

  1. 团队只有一个负责人,编号为 1。除了该负责人外,每个人有且仅有一个领导(负责人没有领导);
  2. 不存在循环管理的情况,如 A 管理 B,B 管理 C,C 管理 A。

力扣想进行的操作有以下三种:

  1. 给团队的一个成员(也可以是负责人)发一定数量的 LeetCoin
  2. 给团队的一个成员(也可以是负责人),以及他/她管理的所有人(即他/她的下属、他/她下属的下属,……),发一定数量的 LeetCoin
  3. 查询某一个成员(也可以是负责人),以及他/她管理的所有人被发到的 LeetCoin 之和。

输入

  1. N 表示团队成员的个数(编号为 1 ~ N,负责人为 1);
  2. leadership 是大小为 (N - 1) * 2 的二维数组,其中每个元素 [a, b] 代表 ba 的下属;
  3. operations是一个长度为 Q 的二维数组,代表以时间排序的操作,格式如下:
    • operations[i][0] = 1: 代表第一种操作,operations[i][1]代表成员的编号,operations[i][2] 代表 LeetCoin 的数量;
    • operations[i][0] = 2: 代表第二种操作,operations[i][1] 代表成员的编号,operations[i][2] 代表 LeetCoin 的数量;
    • operations[i][0] = 3: 代表第三种操作,operations[i][1] 代表成员的编号;

输出

返回一个数组,数组里是每次查询的返回值(发 LeetCoin 的操作不需要任何返回值)。

由于发的 LeetCoin 很多,请把每次查询的结果模 1e9+7 (1000000007)。

样例

输入:N = 6, leadership = [[1, 2], [1, 6], [2, 3], [2, 5], [1, 4]],
operations = [[1, 1, 500], [2, 2, 50], [3, 1], [2, 6, 15], [3, 1]]
输出:[650, 665]
解释:团队的管理关系见下图。
第一次查询时,每个成员得到的LeetCoin的数量分别为(按编号顺序):500, 50, 50, 0, 50, 0;
第二次查询时,每个成员得到的LeetCoin的数量分别为(按编号顺序):500, 50, 50, 0, 50, 15.

限制

  1. 1 <= N <= 50000
  2. 1 <= Q <= 50000
  3. operations[i][0] != 3 时,1 <= operations[i][2] <= 5000

算法

(DFS 序 + 线段树) \(O(n + Qlog n)\)
  1. 首先,将树结构转化为一维序列结构,将子树更新和子树查询转为区间更新和区间查询DFS 序可以解决此问题。具体做法:维护一个时间戳,时间戳从 0 开始。当新递归进入一个结点时,时间戳加 1,且这个结点的 left 记录为更新后的时间戳。然后深度优先递归遍历子树,当这个结点要回溯的时候,记录该结点的 right 为当前的时间戳。这样之后,每个结点就有了一个 left 值和一个 right 值。

    void dfs(int u, int& ts, const vector<vector<int>>& graph,
             vector<int>& left, vector<int> &right) {
        left[u] = ++ts;
        for (int v : graph[x])
            dfs(v, ts, graph, left, right);
    
        right[u] = ts;
    }
    
  2. DFS 序以某个结点为根的子树(包括自己)可以用子树根结点的 leftright 值来表示。以样例为例,整个子树的 DFS 序为 (1, 2, 3, 5, 6, 4),结点 2left 为 2,right 为 4,所以区间 [2, 4] 就是以结点 2 为根的子树。

  3. 带懒惰标记的线段树可以实现区间更新与区间查询。

时间复杂度

  • 建立 DFS 序的时间复杂度为 \(O(n)\)
  • 线段树的建立需要 \(O(n)\) 的时间复杂度,线段树的单次更新和查询需要 \(O(log n)\) 的时间复杂度。
  • 故总时间复杂度为 \(O(n + Qlog n)\)

空间复杂度

  • 线段树和其他数组占用的空间为 \(O(n)\)

C++ 代码

const int MOD = 1000000007;
const int N = 50010;
class Solution {
public:
    struct Node{
        int l, r;
        int sum, add;
    }tr[N << 2];

    void dfs(int x, int& ts, const vector<vector<int>>& graph, vector<int>& left, vector<int> &right) {
        left[x] = ++ts;
        for (int v : graph[x])
            dfs(v, ts, graph, left, right);

        right[x] = ts;
    }
    void build(int u, int l, int r){
        if(l == r) tr[u] = {l ,r, 0, 0};
        else {
            tr[u] = {l, r, 0, 0};
            int mid = (l + r) >> 1;
            build(u<<1, l, mid), build(u<<1|1, mid + 1, r);
            pushup(u);
        }
    }
    void pushdown(int u){
        if(!tr[u].add) return ;
        tr[u<<1].add = (tr[u<<1].add + tr[u].add) % MOD;
        tr[u<<1].sum =  (tr[u<<1].sum + (tr[u<<1].r - tr[u<<1].l + 1) * tr[u].add % MOD) % MOD;
        tr[u<<1|1].add = (tr[u<<1|1].add + tr[u].add) % MOD;
        tr[u<<1|1].sum =  (tr[u<<1|1].sum + (tr[u<<1|1].r - tr[u<<1|1].l + 1) * tr[u].add % MOD) % MOD;
        tr[u].add = 0;
        
    }
    void pushup(int u){
        tr[u].sum = (tr[u<<1].sum + tr[u<<1|1].sum) % MOD;
    }

    int query(int u, int l, int r){
        if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
        
        pushdown(u);
        int mid = (tr[u].l + tr[u].r) >> 1;
        int sum = 0;
        if(l <= mid) sum = query(u<<1, l, r);
        if(r > mid) sum = (sum + query(u<<1|1, l, r)) % MOD;
        return sum;
    }

    void modify(int u, int l, int r, int add){
        if(tr[u].l >= l && tr[u].r <= r) {
            tr[u].sum = (tr[u].sum + (tr[u].r - tr[u].l + 1) * add % MOD) % MOD;
            tr[u].add = (tr[u].add + add) % MOD;
            return;
        }
        pushdown(u);
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) modify(u<<1, l, r, add);
        if(r > mid)  modify(u<<1|1, l, r, add);
        pushup(u);
    }

    vector<int> bonus(int n, vector<vector<int>>& leadership, vector<vector<int>>& operations) {
        vector<vector<int>> graph(n + 1);

        for (const auto &v : leadership)
            graph[v[0]].push_back(v[1]);

        vector<int> left(n + 1), right(n + 1);

        int ts = 0;
        dfs(1, ts, graph, left, right);
        
        build(1, 1, n);
        vector<int> ans;
        for (const auto &op : operations) {
            if (op[0] == 1) modify(1, left[op[1]], left[op[1]], op[2]);
            else if (op[0] == 2) modify(1, left[op[1]], right[op[1]], op[2]);
            else ans.push_back(query(1, left[op[1]], right[op[1]]));
        }
        return ans;
    }
};

指针写法:

#define MOD 1000000007
#define LL long long

struct Node {
    int sum, lazy;
    Node *l, *r;
    Node(): sum(0), lazy(0), l(NULL), r(NULL){}
};

class Solution {
public:
    void dfs(int x, int& ts, const vector<vector<int>>& graph,
             vector<int>& left, vector<int> &right) {
        left[x] = ++ts;
        for (int v : graph[x])
            dfs(v, ts, graph, left, right);

        right[x] = ts;
    }

    Node* build(int L, int R) {
        Node *ret = new Node();
        if (L == R)
            return ret;

        int mid = (L + R) >> 1;
        ret -> l = build(L, mid);
        ret -> r = build(mid + 1, R);

        return ret;
    }

    void pushdown(Node *cur, int sz) {
        cur -> l -> lazy = (cur -> l -> lazy + cur -> lazy) % MOD;
        cur -> r -> lazy = (cur -> r -> lazy + cur -> lazy) % MOD;
        cur -> l -> sum =
            (cur -> l -> sum + (LL)(cur -> lazy) * (sz - (sz >> 1))) % MOD;
        cur -> r -> sum =
            (cur -> r -> sum + (LL)(cur -> lazy) * (sz >> 1)) % MOD;
        cur -> lazy = 0;
    }

    int query(int L, int R, int l, int r, Node *cur) {
        if (L <= l && r <= R)
            return cur -> sum;

        pushdown(cur, r - l + 1);

        int mid = (l + r) >> 1, ret = 0;
        if (L <= mid) ret = (ret + query(L, R, l, mid, cur -> l)) % MOD;
        if (mid < R) ret = (ret + query(L, R, mid + 1, r, cur -> r)) % MOD;

        return ret;
    }

    void update(int p, int x, int l, int r, Node *cur) {
        if (l == r) {
            cur -> sum = (cur -> sum + x) % MOD;
            return;
        }

        pushdown(cur, r - l + 1);
        int mid = (l + r) >> 1;
        if (p <= mid) update(p, x, l, mid, cur -> l);
        else update(p, x, mid + 1, r, cur -> r);

        cur -> sum = (cur -> l -> sum + cur -> r -> sum) % MOD;
    }

    void update(int L, int R, int x, int l, int r, Node *cur) {
        if (L <= l && r <= R) {
            cur -> lazy = (cur -> lazy + x) % MOD;
            cur -> sum = (cur -> sum + (LL)(x) * (r - l + 1)) % MOD;
            return;
        }

        pushdown(cur, r - l + 1);

        int mid = (l + r) >> 1;
        if (L <= mid) update(L, R, x, l, mid, cur -> l);
        if (mid < R) update(L, R, x, mid + 1, r, cur -> r);

        cur -> sum = (cur -> l -> sum + cur -> r -> sum) % MOD;
    }

    vector<int> bonus(int n, vector<vector<int>>& leadership,
                      vector<vector<int>>& operations) {
        vector<vector<int>> graph(n + 1);

        for (const auto &v : leadership)
            graph[v[0]].push_back(v[1]);

        vector<int> left(n + 1), right(n + 1);

        int ts = 0;
        dfs(1, ts, graph, left, right);


        Node *root = build(1, n);

        vector<int> ans;
        for (const auto &op : operations) {
            if (op[0] == 1) update(left[op[1]], op[2], 1, n, root);
            else if (op[0] == 2) update(left[op[1]], right[op[1]], op[2], 1, n, root);
            else ans.push_back(query(left[op[1]], right[op[1]], 1, n, root));
        }

        return ans;
    }
};

posted @ 2022-02-20 18:56  pxlsdz  阅读(3580)  评论(0编辑  收藏  举报