Handling Sum Queries After Update

Handling Sum Queries After Update

You are given two 0-indexed arrays nums1 and nums2 and a 2D array queries of queries. There are three types of queries:

1. For a query of type 1, queries[i] = [1, l, r] . Flip the values from $0$ to $1$ and from $1$ to $0$ in nums1 from index $l$ to index $r$. Both $l$ and $r$ are 0-indexed.
2. For a query of type 2, queries[i] = [2, p, 0] . For every index 0 <= i < n , set nums2[i] = nums2[i] + nums1[i] * p .
3. For a query of type 3, queries[i] = [3, 0, 0] . Find the sum of the elements in nums2 .

Return an array containing all the answers to the third type queries.

Example 1:

Input: nums1 = [1,0,1], nums2 = [0,0,0], queries = [[1,1,1],[2,1,0],[3,0,0]]
Output: [3]
Explanation: After the first query nums1 becomes [1,1,1]. After the second query, nums2 becomes [1,1,1], so the answer to the third query is 3. Thus, [3] is returned.

Example 2:

Input: nums1 = [1], nums2 = [5], queries = [[2,0,0],[3,0,0]]
Output: [5]
Explanation: After the first query, nums2 remains [5], so the answer to the second query is 5. Thus, [5] is returned.
 

 

解题思路

  真就连lc都打不动了,昨晚周赛前三题全是贪心思维题,被第二题卡了半个多钟心态都炸力,日常被思维题爆杀.jpg。最后一题一眼线段树,结果估计是太紧张也没多少时间,思路一直没理清,结果今天早上想了下就做出来了。

  本质就是求整个区间有多少个$1$,那么对于第$2$个操作,等价于对答案直接累加$s \times p$,其中$s$是当前整个区间$1$的个数。对于第$3$个操作就直接输出累计的答案好了(记得一开始把$\text{nums2}$中的数先累加起来)。

  怎么用线段树去维护区间的$01$翻转呢?如果要对整个区间$[l,r]$翻转,假设当前区间$[l,r]$有$s$个$1$,那么翻转后就变成了$r - l + 1 - s$个$1$。同时由于是区间修改,因此线段树要带个懒标记,每次对整个区间翻转时都对当前区间的懒标记异或$1$就好了(相当于模$2$加法),当需要向下传的时候如果懒标记为$1$,说明需要对两个儿子的区间进行翻转。因此线段树需要维护区间中$1$的个数以及懒标记。

  AC代码如下:

 1 class Solution {
 2 public:
 3     struct Node {
 4         int l, r, s, add;
 5     };
 6     
 7     vector<long long> handleQuery(vector<int>& nums1, vector<int>& nums2, vector<vector<int>>& queries) {
 8         int n = nums1.size();
 9         long long ret = 0;
10         for (auto &x : nums2) {
11             ret += x;
12         }
13         vector<Node> tr(n * 4);
14         function<void(int, int, int)> build = [&](int u, int l, int r) {
15             if (l == r) {
16                 tr[u] = {l, r, nums1[l - 1], 0};
17             }
18             else {
19                 int mid = l + r >> 1;
20                 build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
21                 tr[u] = {l, r, tr[u << 1].s + tr[u << 1 | 1].s, 0};
22             }
23         };
24         function<void(int)> pushdown = [&](int u) {
25             if (tr[u].add) {
26                 tr[u << 1].s = tr[u << 1].r - tr[u << 1].l + 1 - tr[u << 1].s;
27                 tr[u << 1].add ^= 1;
28                 tr[u << 1 | 1].s = tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1 - tr[u << 1 | 1].s;
29                 tr[u << 1 | 1].add ^= 1;
30                 tr[u].add = 0;
31             }
32         };
33         function<void(int, int, int)> modify = [&](int u, int l, int r) {
34             if (tr[u].l >= l && tr[u].r <= r) {
35                 tr[u].s = tr[u].r - tr[u].l + 1 - tr[u].s;
36                 tr[u].add ^= 1;
37             }
38             else {
39                 pushdown(u);
40                 int mid = tr[u].l + tr[u].r >> 1;
41                 if (l <= mid) modify(u << 1, l, r);
42                 if (r >= mid + 1) modify(u << 1 | 1, l, r);
43                 tr[u].s = tr[u << 1].s + tr[u << 1 | 1].s;
44             }
45         };
46         function<int(int, int, int)> query = [&](int u, int l, int r) {
47             if (tr[u].l >= l && tr[u].r <= r) return tr[u].s;
48             else {
49                 pushdown(u);
50                 int mid = tr[u].l + tr[u].r >> 1, s = 0;
51                 if (l <= mid) s = query(u << 1, l, r);
52                 if (r >= mid + 1) s += query(u << 1 | 1, l, r);
53                 return s;
54             }
55         };
56         build(1, 1, n);
57         vector<long long> ans;
58         for (auto &p : queries) {
59             if (p[0] == 1) modify(1, p[1] + 1, p[2] + 1);
60             else if (p[0] == 2) ret += 1ll * query(1, 1, n) * p[1];
61             else ans.push_back(ret);
62         }
63         return ans;
64     }
65 };

  参考洛谷的一道题:P3870 [TJOI2009] 开关,也是$01$区间翻转问题。

  贴个AC代码:

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 
 4 const int N = 1e5 + 10;
 5 
 6 struct Node {
 7     int l, r, s, add;
 8 }tr[N * 4];
 9 
10 void build(int u, int l, int r) {
11     if (l == r) {
12         tr[u] = {l, r, 0, 0};
13     }
14     else {
15         int mid = l + r >> 1;
16         build(u << 1, l, mid);
17         build(u << 1 | 1, mid + 1, r);
18         tr[u] = {l, r, 0, 0};
19     }
20 }
21 
22 void pushdown(int u) {
23     if (tr[u].add) {
24         tr[u << 1].s = tr[u << 1].r - tr[u << 1].l + 1 - tr[u << 1].s;
25         tr[u << 1].add ^= 1;
26         tr[u << 1 | 1].s = tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1 - tr[u << 1 | 1].s;
27         tr[u << 1 | 1].add ^= 1;
28         tr[u].add = 0;
29     }
30 }
31 
32 void modify(int u, int l, int r) {
33     if (tr[u].l >= l && tr[u].r <= r) {
34         tr[u].s = tr[u].r - tr[u].l + 1 - tr[u].s;
35         tr[u].add ^= 1;
36     }
37     else {
38         pushdown(u);
39         int mid = tr[u].l + tr[u].r >> 1;
40         if (l <= mid) modify(u << 1, l, r);
41         if (r >= mid + 1) modify(u << 1 | 1, l, r);
42         tr[u].s = tr[u << 1].s + tr[u << 1 | 1].s;
43     }
44 }
45 
46 int query(int u, int l, int r) {
47     if (tr[u].l >= l && tr[u].r <= r) return tr[u].s;
48     pushdown(u);
49     int mid = tr[u].l + tr[u].r >> 1, s = 0;
50     if (l <= mid) s = query(u << 1, l, r);
51     if (r >= mid + 1) s += query(u << 1 | 1, l, r);
52     return s;
53 }
54 
55 int main() {
56     int n, m;
57     scanf("%d %d", &n, &m);
58     build(1, 1, n);
59     while (m--) {
60         int op, l, r;
61         scanf("%d %d %d", &op, &l, &r);
62         if (!op) modify(1, l, r);
63         else printf("%d\n", query(1, l, r));
64     }
65     
66     return 0;
67 }
posted @ 2023-02-19 10:16  onlyblues  阅读(14)  评论(0编辑  收藏  举报
Web Analytics