Codeforces 719E (线段树教做人系列) 线段树维护矩阵
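题面简洁明了，一看就懂：给定数组 $a_1,\dots,a_n$，支持两种操作——区间 $[l,r]$ 内每个 $a_i$ 加上 $x$；询问 $\sum_{i=l}^{r} F_{a_i} \bmod (10^9+7)$，其中 $F$ 为斐波那契数列。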
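做了这道题之后，才知道怎么用线段树维护递推式。递推式的每一步递推都可以看作一次矩阵乘法：设 $A$ 为初始值矩阵，$B$ 为变换矩阵，则第 $n$ 项对应 $A B^{n-1}$，即把矩阵 $B$ 乘了 $n-1$ 次。对斐波那契数列可取 $B = \begin{pmatrix}0 & 1\\ 1 & 1\end{pmatrix}$，此时 $(F_{k-1}, F_k)\,B = (F_k, F_{k+1})$。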
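因此线段树的每个结点维护区间内所有行向量 $(F_{a_i-1}, F_{a_i})$ 之和 sum（存放在 2×2 矩阵的第 0 行），懒标记 flag 直接存储尚未下放的累积变换矩阵 $B^{k}$。由于矩阵乘法对加法满足分配律，$\left(\sum_i A_i\right) B^{k} = \sum_i A_i B^{k}$，所以区间整体加 $x$ 时只需把结点的 sum 和 flag 都右乘 $B^{x}$。$B^{x}$ 在每次修改操作时用快速幂求一次；下放懒标记时只需做矩阵乘法，不再需要快速幂。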
// Codeforces 719E "Sasha and Array": range "add x to every a[i]" updates and
// range "sum of Fib(a[i]) mod 1e9+7" queries. Each leaf stores the row vector
// (F(v-1), F(v)) for its current value v; a range add of x right-multiplies the
// stored sums (and the lazy tags) by the Fibonacci step matrix raised to x.
#include <cstdio>
#include <cstring>
#include <vector>

using LL = long long;

const LL mod = 1000000007;

// Children of node o in the 1-based, heap-ordered segment tree.
inline int ls(int o) { return o << 1; }
inline int rs(int o) { return (o << 1) | 1; }

// 2x2 matrix whose entries are kept reduced modulo `mod`.
struct Matrix {
    static const int len = 2;
    LL x[len][len];

    // Set to the identity matrix.
    void init() {
        zero();
        for (int i = 0; i < len; i++) x[i][i] = 1;
    }

    // Set every entry to 0.
    void zero() { std::memset(x, 0, sizeof(x)); }

    Matrix operator*(const Matrix& m) const {
        Matrix ans;
        ans.zero();
        for (int i = 0; i < len; i++)
            for (int j = 0; j < len; j++)
                for (int k = 0; k < len; k++)
                    // Both factors are < mod, so the product (< ~1e18) fits in long long.
                    ans.x[i][j] = (ans.x[i][j] + x[i][k] * m.x[k][j]) % mod;
        return ans;
    }

    Matrix operator+(const Matrix& m) const {
        Matrix ans;
        for (int i = 0; i < len; i++)
            for (int j = 0; j < len; j++)
                ans.x[i][j] = (x[i][j] + m.x[i][j]) % mod;
        return ans;
    }

    // Binary exponentiation: returns this^b for b >= 0.
    // The exponent is long long so the LL increments read in main() are not
    // narrowed. A negative b returns the identity instead of looping forever
    // (an arithmetic b >>= 1 never reaches 0 for negative b).
    Matrix operator^(LL b) const {
        Matrix ans;
        ans.init();
        Matrix base = *this;
        for (; b > 0; b >>= 1) {
            if (b & 1) ans = ans * base;
            base = base * base;
        }
        return ans;
    }
};

// mul:   Fibonacci step [[0,1],[1,1]]; (F(k-1), F(k)) * mul == (F(k), F(k+1)).
// trans: start row (F(-1), F(0)) == (1, 0), so row 0 of trans * mul^v is
//        (F(v-1), F(v)) for every v >= 0 (including v == 0).
// tmp:   scratch holding mul^x for the current update.
Matrix mul, tmp, trans;
std::vector<int> a;  // a[1..n]: input values (index 0 unused)

struct SegementTree {
    bool lz;      // true while `flag` holds a transform not yet pushed to the children
    Matrix sum;   // row 0: sum over the segment of (F(v-1), F(v)); row 1 stays 0
    Matrix flag;  // pending right-multiplier for the children; identity when !lz
};
std::vector<SegementTree> tr;

// Recompute node o's sum from its children.
void maintain(int o) { tr[o].sum = tr[ls(o)].sum + tr[rs(o)].sum; }

// Apply node o's pending transform to both children and clear it.
// Valid because (A + B) * M == A * M + B * M.
void pushdown(int o) {
    if (!tr[o].lz) return;
    const Matrix f = tr[o].flag;
    const int kids[2] = {ls(o), rs(o)};
    for (int c : kids) {
        tr[c].sum = tr[c].sum * f;
        tr[c].flag = tr[c].flag * f;
        tr[c].lz = true;
    }
    tr[o].lz = false;
    tr[o].flag.init();
}

// Build node o covering a[l..r] (requires l <= r).
void build(int o, int l, int r) {
    tr[o].sum.zero();
    tr[o].lz = false;
    tr[o].flag.init();
    if (l == r) {
        tr[o].sum = trans * (mul ^ a[l]);  // row 0 = (F(a[l]-1), F(a[l]))
        return;
    }
    int mid = (l + r) >> 1;
    build(ls(o), l, mid);
    build(rs(o), mid + 1, r);
    maintain(o);
}

// Right-multiply every leaf in [ql, qr] by `now` (i.e. add its exponent to those values).
void update(int o, int l, int r, int ql, int qr, const Matrix& now) {
    if (l >= ql && r <= qr) {
        tr[o].sum = tr[o].sum * now;
        tr[o].flag = tr[o].flag * now;
        tr[o].lz = true;
        return;
    }
    pushdown(o);
    int mid = (l + r) >> 1;
    if (ql <= mid) update(ls(o), l, mid, ql, qr, now);
    if (qr > mid) update(rs(o), mid + 1, r, ql, qr, now);
    maintain(o);
}

// Sum of F(a[i]) over i in [ql, qr], modulo `mod`.
LL query(int o, int l, int r, int ql, int qr) {
    if (l >= ql && r <= qr) return tr[o].sum.x[0][1];
    pushdown(o);
    int mid = (l + r) >> 1;
    LL ans = 0;
    if (ql <= mid) ans = (ans + query(ls(o), l, mid, ql, qr)) % mod;
    if (qr > mid) ans = (ans + query(rs(o), mid + 1, r, ql, qr)) % mod;
    return ans;
}

int main() {
    int n, m;
    // build() recurses forever on an empty range, so require n >= 1.
    if (std::scanf("%d%d", &n, &m) != 2 || n < 1) return 0;

    trans.zero();
    trans.x[0][0] = 1;  // (F(-1), F(0)) = (1, 0)
    mul.zero();
    mul.x[0][1] = mul.x[1][0] = mul.x[1][1] = 1;

    // Size storage from n instead of a fixed compile-time bound.
    a.assign(static_cast<std::size_t>(n) + 1, 0);
    for (int i = 1; i <= n; i++) {
        if (std::scanf("%d", &a[i]) != 1) return 0;
    }
    tr.assign(static_cast<std::size_t>(n) * 4, SegementTree{});
    build(1, 1, n);

    for (int i = 1; i <= m; i++) {
        int op, l, r;
        if (std::scanf("%d%d%d", &op, &l, &r) != 3) break;
        if (op == 1) {
            LL x;
            if (std::scanf("%lld", &x) != 1) break;
            tmp = mul ^ x;  // one fast power per update; pushdown only multiplies
            update(1, 1, n, l, r, tmp);
        } else {
            std::printf("%lld\n", query(1, 1, n, l, r));
        }
    }
    return 0;
}