LightOJ 1135 - Count the Multiples of 3 线段树
http://www.lightoj.com/volume_showproblem.php?problem=1135
题意:给定两个操作,一个对区间所有元素加1,一个询问区间能被3整除的数有多少个。
思路:要求被3整除,我们可以记录3个状态,当前区间模3余1的 余2的 余0的,那么对一个数增加的时候,直接交换不同余数下的个数就可以了。
/** @Date : 2016-12-06-20.00 * @Author : Lweleth (SoungEarlf@gmail.com) * @Link : https://github.com/ * @Version : */ #include<bits/stdc++.h> #define LL long long #define PII pair #define MP(x, y) make_pair((x),(y)) #define fi first #define se second #define PB(x) push_back((x)) #define MMG(x) memset((x), -1,sizeof(x)) #define MMF(x) memset((x),0,sizeof(x)) #define MMI(x) memset((x), INF, sizeof(x)) using namespace std; const int INF = 0x3f3f3f3f; const int N = 1e5+20; const double eps = 1e-8; struct yuu { int l, r; int add; int m0, m1, m2; }tt[N << 2]; void pushup(int p) { tt[p].m0 = tt[p << 1].m0 + tt[p << 1 | 1].m0; tt[p].m1 = tt[p << 1].m1 + tt[p << 1 | 1].m1; tt[p].m2 = tt[p << 1].m2 + tt[p << 1 | 1].m2; } void pushdown(int p) { if(tt[p].add != 0) { tt[p].add %= 3; /// tt[p << 1].add += tt[p].add; if(tt[p].add == 2) { swap(tt[p << 1].m0 , tt[p << 1].m1); swap(tt[p << 1].m0 , tt[p << 1].m2); } else if(tt[p].add == 1) { swap(tt[p << 1].m0 , tt[p << 1].m2); swap(tt[p << 1].m1 , tt[p << 1].m0); } /// tt[p << 1 | 1].add += tt[p].add; if(tt[p].add == 2) { swap(tt[p << 1 | 1].m0 , tt[p << 1 | 1].m1); swap(tt[p << 1 | 1].m0 , tt[p << 1 | 1].m2); } else if(tt[p].add == 1) { swap(tt[p << 1 | 1].m0 , tt[p << 1 | 1].m2); swap(tt[p << 1 | 1].m1 , tt[p << 1 | 1].m0); } tt[p].add = 0; } } void build(int l, int r, int p) { tt[p].l = l; tt[p].r = r; tt[p].add = tt[p].m0 = tt[p].m2 = tt[p].m1 = 0; if(l == r) { tt[p].m0 = 1; return ; } int mid = (l + r) >> 1; build(l , mid, p << 1); build(mid + 1, r, p << 1 | 1); pushup(p); } void updata(int l, int r, int v, int p) { if(l <= tt[p].l && r >= tt[p].r) { tt[p].add += v; swap(tt[p].m0 , tt[p].m2); swap(tt[p].m1 , tt[p].m0); return ; } pushdown(p); int mid = (tt[p].l + tt[p].r) >> 1; if(l <= mid) updata(l, r, v, p << 1); if(r > mid) updata(l, r, v, p << 1 | 1); pushup(p); } int query(int l, int r, int p) { if(l <= tt[p].l && r >= tt[p].r) { return tt[p].m0; } pushdown(p); int mid = (tt[p].l + tt[p].r) >> 1; int ans = 0; if(l <= mid) ans += query(l, r, p << 1); if(r > mid) ans += query(l, r, p << 1 | 1); return ans; } int main() { int T; int cnt = 0; cin >> T; while(T--) { int n, q; scanf("%d%d", &n, &q); build(1, n, 1); printf("Case %d:\n", ++cnt); while(q--) { int t, x, y; scanf("%d%d%d", &t ,&x ,&y); if(t) printf("%d\n", query(x+1, y+1, 1)); else updata(x+1, y+1, 1, 1); } } return 0; }