洛谷P3328(bzoj 4085)毒瘤线段树
题面及大致思路:https://www.cnblogs.com/Yangrui-Blog/p/9623294.html, https://www.cnblogs.com/New-Godess/p/4567282.html
每个点维护 2 个矩阵,共 2×3 + 3×3 = 15 个变量。矩阵 a 是 2×3 的 [a(i - 1), a(i), a(i + 1); b(i - 1), b(i), b(i + 1)];矩阵 b 是 3×3 的,其元素就是 a(i - 1), a(i), a(i + 1) 与 b(i - 1), b(i), b(i + 1) 的两两乘积。矩阵转移的过程很显然,就不细说了。这个题的思维难度不高,只是有两点很烦人:一是卡常,二是要维护的变量很多,写起来很麻烦。
代码:
// Luogu P3328 / BZOJ 4085 solution: a lazy segment tree over positions 2..n-1.
//
// The sequence f is defined by the transition matrix set up in init_p():
//   f(1) = 1, f(2) = 2, f(k + 1) = f(k) + a * f(k - 1) + b   (all mod `mod`).
// Each tree position i tracks the f-values around its left neighbour x_{i-1}
// and its right neighbour x_{i+1}. Queries return
// sum f(x_{i-1} + 1) * f(x_{i+1} - 1) over a range of i.
#include <cstdio>
#include <cstring>

// Child indices in the implicit (heap-ordered) segment tree.
inline int ls(int x) { return x << 1; }
inline int rs(int x) { return (x << 1) | 1; }

constexpr int mod = 1000000007;
constexpr int maxn = 300010;

// c[i][0]: input value x_i at position i (1-based).
// c[i][1], c[i][2], c[i][3]: f(x_i - 1), f(x_i), f(x_i + 1) modulo mod.
int c[maxn][4];
// Recurrence coefficients, and a^(mod-2) mod `mod` (a's inverse; only used when a != 0).
int a, b, inv;

// Computes x^y mod `mod` by binary exponentiation (expects y >= 0).
int qpow(int x, int y) {
    int ans = 1;
    for (; y; y >>= 1) {
        if (y & 1) ans = 1ll * ans * x % mod;
        x = 1ll * x * x % mod;
    }
    return ans;
}

// Small dense matrix (at most 3x3) of residues modulo `mod`.
struct Matrix {
    int a[3][3], n, m;

    // Turns this into the x-by-x identity matrix.
    void init(int x) {
        std::memset(a, 0, sizeof(a));
        n = m = x;
        for (int i = 0; i < n; i++) a[i][i] = 1;
    }

    // NOTE: a specialised product, not a general one. Only result columns 0
    // and 1 are computed. Column 2 stays zero except ret.a[2][2], which is
    // forced to 1. That is exact for products of matrices whose third column is
    // (0, 0, 1)^T, such as the transition matrix B. For the 1x3 seed row it only
    // leaves entry [0][2] wrong, and that entry is never read.
    Matrix operator*(const Matrix &rhs) const {
        Matrix ret;
        std::memset(ret.a, 0, sizeof(ret.a));
        ret.n = n, ret.m = rhs.m;
        for (int i = 0; i < n; i++)
            for (int k = 0; k < m; k++)
                for (int j = 0; j < 2; j++)
                    ret.a[i][j] = (ret.a[i][j] + 1ll * a[i][k] * rhs.a[k][j] % mod) % mod;
        ret.a[2][2] = 1;
        return ret;
    }
};

// A: seed row [f(2), f(1), 1]; B: transition matrix; p[i] = B^(2^(i-1)); E: identity.
Matrix A, B, p[35], E;

// Returns B^k, built from the precomputed powers p[i].
// NOTE(review): assumes k >= 0. A negative k (an input value below 2) never
// terminates, because k >>= 1 stays at -1 while i runs past the end of p.
Matrix qpow(int k) {
    Matrix ans = E;
    for (int i = 1; k; i++, k >>= 1) {
        if (k & 1) ans = ans * p[i];
    }
    return ans;
}

// Fills c[pos][1] = f(x - 1) and c[pos][2] = f(x), where x = c[pos][0],
// using [f(x), f(x-1), 1] = A * B^(x-2).
void get_Matrix(int pos) {
    Matrix tmp = A * qpow(c[pos][0] - 2);
    c[pos][1] = tmp.a[0][1], c[pos][2] = tmp.a[0][0];
}

// Sets up the seed row A = [2, 1, 1] and the transition matrix
//       | 1 1 0 |
//   B = | a 0 0 |     so that [u, v, 1] * B = [u + a*v + b, u, 1],
//       | b 0 1 |
// then stores the repeated squares p[i] = B^(2^(i-1)) for i = 1..32.
// Must run after a and b are read; it overwrites the global B while squaring.
void init_p() {
    A.n = 1, A.m = 3;
    A.a[0][0] = 2, A.a[0][1] = 1, A.a[0][2] = 1;
    B.n = B.m = 3;
    B.a[0][0] = 1;
    B.a[1][0] = a, B.a[2][0] = b;
    B.a[0][1] = 1;
    B.a[2][2] = 1;
    for (int i = 1; i <= 32; i++) {
        p[i] = B;
        B = B * B;
    }
}

// Segment-tree node covering positions [l, r]. For a position i let
//   L_i = (f(x_{i-1} - 1), f(x_{i-1}), f(x_{i-1} + 1))
//   R_i = (f(x_{i+1} - 1), f(x_{i+1}), f(x_{i+1} + 1)).
// The node stores (sums over i in [l, r], mod `mod`):
//   sum[0][k] = sum L_i[k],   sum[1][k] = sum R_i[k],
//   val[j][k] = sum L_i[j] * R_i[k],
//   lz[flag]  = shift of side `flag` (0 = L, 1 = R) already applied to this
//               node but not yet propagated to its children.
struct SegementTree {
    int sum[2][3], val[3][3], l, r, lz[2];
};
SegementTree tr[maxn * 4];

// Recomputes node `now` as the element-wise sum of its two children.
inline void pushup(int now) {
    const SegementTree &lc = tr[ls(now)];
    const SegementTree &rc = tr[rs(now)];
    SegementTree &t = tr[now];
    for (int i = 0; i < 3; i++)
        for (int j = 0; j < 3; j++) t.val[i][j] = (lc.val[i][j] + rc.val[i][j]) % mod;
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 3; j++) t.sum[i][j] = (lc.sum[i][j] + rc.sum[i][j]) % mod;
}

// Advances side `flag` of node `now` by one step (x -> x + 1 for every covered
// position), using g(k+1) = g(k) + a*g(k-1) + b*count for the summed triple.
inline void add1(int now, int flag) {
    SegementTree &t = tr[now];
    const int len = t.r - t.l + 1;
    int *s = t.sum[flag];
    s[0] = s[1];
    s[1] = s[2];
    s[2] = (s[1] + 1ll * s[0] * a + 1ll * b * len) % mod;
    if (flag == 0) {
        // Rows of val follow L; the b-term is weighted by the other side's sums.
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 3; j++) t.val[i][j] = t.val[i + 1][j];
        for (int j = 0; j < 3; j++)
            t.val[2][j] = (t.val[1][j] + 1ll * t.val[0][j] * a + 1ll * b * t.sum[1][j]) % mod;
    } else {
        // Columns of val follow R.
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 2; j++) t.val[i][j] = t.val[i][j + 1];
        for (int i = 0; i < 3; i++)
            t.val[i][2] = (t.val[i][1] + 1ll * t.val[i][0] * a + 1ll * b * t.sum[0][i]) % mod;
    }
}

// Inverts one step of g(k+1) = g(k) + a*g(k-1) + bs. Given cur = g(k) and
// next = g(k+1), with bs already reduced mod `mod`, it returns g(k-1) in [0, mod).
// When a == 0 the a-term vanishes, so g(k-1) is taken from g(k) = g(k-1) + bs instead.
inline int prev_term(int next, int cur, long long bs) {
    if (a == 0) return static_cast<int>((cur - bs + mod) % mod);
    return static_cast<int>((((long long)next - cur - bs) % mod * inv % mod + mod) % mod);
}

// Moves side `flag` of node `now` back by one step (x -> x - 1).
inline void dec1(int now, int flag) {
    SegementTree &t = tr[now];
    const int len = t.r - t.l + 1;
    int *s = t.sum[flag];
    s[2] = s[1];
    s[1] = s[0];
    s[0] = prev_term(s[2], s[1], 1ll * b * len % mod);
    if (flag == 0) {
        for (int i = 1; i >= 0; i--)
            for (int j = 0; j < 3; j++) t.val[i + 1][j] = t.val[i][j];
        for (int j = 0; j < 3; j++)
            t.val[0][j] = prev_term(t.val[2][j], t.val[1][j], 1ll * b * t.sum[1][j] % mod);
    } else {
        for (int i = 0; i < 3; i++)
            for (int j = 1; j >= 0; j--) t.val[i][j + 1] = t.val[i][j];
        for (int i = 0; i < 3; i++)
            t.val[i][0] = prev_term(t.val[i][2], t.val[i][1], 1ll * b * t.sum[0][i] % mod);
    }
}

// Applies a shift of y steps on side `flag` to node `now` itself (y > 0
// forward, y < 0 backward). It goes one step at a time, so the cost is O(|y|).
inline void pushdown(int now, int flag, int y) {
    for (; y > 0; --y) add1(now, flag);
    for (; y < 0; ++y) dec1(now, flag);
}

// Propagates node `now`'s pending shifts to both children and clears them.
inline void Pushdown(int now) {
    for (int flag = 0; flag < 2; flag++) {
        const int shift = tr[now].lz[flag];
        if (shift == 0) continue;
        pushdown(ls(now), flag, shift);
        pushdown(rs(now), flag, shift);
        tr[ls(now)].lz[flag] += shift;
        tr[rs(now)].lz[flag] += shift;
        tr[now].lz[flag] = 0;
    }
}

// Builds the subtree for positions [l, r] (requires l <= r) from the table c.
inline void build(int now, int l, int r) {
    tr[now].l = l, tr[now].r = r;
    if (l == r) {
        for (int i = 0; i < 3; i++) {
            tr[now].sum[0][i] = c[l - 1][i + 1];
            tr[now].sum[1][i] = c[l + 1][i + 1];
        }
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 3; j++)
                tr[now].val[i][j] = 1ll * tr[now].sum[0][i] * tr[now].sum[1][j] % mod;
        return;
    }
    int mid = (l + r) >> 1;
    build(ls(now), l, mid);
    build(rs(now), mid + 1, r);
    pushup(now);
}

// Shifts side `flag` by `val` steps for every position in [ql, qr].
inline void update(int now, int ql, int qr, int flag, int val) {
    int l = tr[now].l, r = tr[now].r;
    if (l > qr || r < ql) return;
    if (l >= ql && r <= qr) {
        tr[now].lz[flag] += val;
        pushdown(now, flag, val);
        return;
    }
    Pushdown(now);
    int mid = (l + r) >> 1;
    if (ql <= mid) update(ls(now), ql, qr, flag, val);
    if (qr > mid) update(rs(now), ql, qr, flag, val);
    pushup(now);
}

// Returns sum of val[2][0] = f(x_{i-1} + 1) * f(x_{i+1} - 1) over i in [ql, qr], mod `mod`.
inline int query(int now, int ql, int qr) {
    int l = tr[now].l, r = tr[now].r;
    if (l > qr || r < ql) return 0;
    if (l >= ql && r <= qr) return tr[now].val[2][0];
    Pushdown(now);
    int mid = (l + r) >> 1;
    int ans = 0;
    if (ql <= mid) ans = (ans + query(ls(now), ql, qr)) % mod;
    if (qr > mid) ans = (ans + query(rs(now), ql, qr)) % mod;
    return ans;
}

int main() {
    int n, m;
    if (std::scanf("%d%d%d%d", &n, &m, &a, &b) != 4) return 0;
    inv = qpow(a, mod - 2);
    E.init(3);
    init_p();
    for (int i = 1; i <= n; i++) {
        std::scanf("%d", &c[i][0]);
        get_Matrix(i);
        c[i][3] = (c[i][2] + 1ll * a * c[i][1] % mod + b) % mod;  // f(x_i + 1)
    }
    // The tree covers positions 2..n-1, the only ones with both neighbours.
    // For n < 3 that range is empty: build() would recurse forever, and every
    // query answer is 0.
    const bool has_inner = n >= 3;
    if (has_inner) build(1, 2, n - 1);
    char op[10];
    while (m--) {
        int x, y;
        if (std::scanf("%9s%d%d", op, &x, &y) != 3) break;  // bounded read; stop on truncated input
        if (op[0] == 'p' || op[0] == 'm') {
            if (!has_inner) continue;
            const int delta = op[0] == 'p' ? 1 : -1;
            // Positions x..y change by delta. A value at j is the left neighbour
            // of position j + 1 and the right neighbour of position j - 1.
            update(1, x + 1, y + 1, 0, delta);
            update(1, x - 1, y - 1, 1, delta);
        } else {
            std::printf("%d\n", has_inner ? query(1, x + 1, y - 1) : 0);
        }
    }
    return 0;
}