【线段树 矩阵乘法dp】8.rseq
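题目分析

用线段树维护区间内 $3\times 3$ 转移矩阵的乘积。状态取行向量 $(x, y, 1)$，第 $i$ 个位置（值为 $c_i$）的转移为

$$
(x,\ y,\ 1)\begin{pmatrix} c_i & 1 & 0 \\ 0 & 2 & 0 \\ w_i c_i & 0 & 1 \end{pmatrix} = \bigl(c_i(x + w_i),\ x + 2y,\ 1\bigr),\qquad w_1 = 1,\quad w_i = 2^{i-2}\ (i \ge 2).
$$

各变量的含义如下：
- $x$：以当前位置结尾的连乘段的加权和。
- $y$：已经结束的段的加权和。每往后一个位置，它所在方案中多一个可自由选择的间隔，因此权值乘 2。

初始状态为 $(0, 0, 1)$，答案为 $(x_n + y_n) \bmod 998244353$。展开可以验证，它等于在相邻两数之间任意填 $+$ 或 $\times$ 的全部 $2^{n-1}$ 种方案的表达式值之和。

单点修改时，只需重建对应叶子的矩阵并沿路径向上合并，每次操作做 $O(\log n)$ 次矩阵乘法。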
#include <bits/stdc++.h>

// Segment tree over 3x3 transfer matrices, all arithmetic mod MO.
//
// Row-vector state (x, y, 1); position i with value c_i maps it to
//     x' = c_i * (x + w_i),   y' = x + 2 * y,
// where w_1 = 1 and w_i = 2^(i-2) for i >= 2. Starting from (0, 0, 1) the
// printed answer is x_n + y_n. Expanding the recurrence, this equals the sum
// of the values of all 2^(n-1) expressions formed by putting '+' or '*'
// between adjacent c_i. Node rt stores the ordered product M_l * ... * M_r of
// its range, so a point update costs O(log n) matrix products.

const int MO = 998244353;
const int maxn = 200035;

struct Matrix {
    int a[3][3];

    // Leaf transfer matrix for a position holding value c whose left weight
    // is spe (already reduced mod MO). c is normalised into [0, MO) so that
    // every stored entry stays in [0, MO) -- mult() relies on that bound.
    void init(int c, int spe)
    {
        c %= MO;
        if (c < 0) c += MO;
        a[0][0] = c;                         a[0][1] = 1; a[0][2] = 0;
        a[1][0] = 0;                         a[1][1] = 2; a[1][2] = 0;
        a[2][0] = (int)(1LL * spe * c % MO); a[2][1] = 0; a[2][2] = 1;
    }
} f[maxn << 2];

int n, m;

// Reads the next integer from stdin. Any '-' seen while skipping non-digit
// characters makes the result negative. Returns 0 at end of input instead of
// waiting forever for a digit; ch is an int so isdigit() only ever sees
// unsigned-char values or EOF.
int read()
{
    int ch = getchar(), sign = 1;
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    int num = 0;
    while (ch != EOF && isdigit(ch)) {
        num = num * 10 + (ch - '0');
        ch = getchar();
    }
    return num * sign;
}

// a^b mod MO for b >= 0. Any negative exponent yields 1 (see left_weight).
int qmi(int a, int b)
{
    if (b < 0) return 1;
    int ret = 1;
    for (; b; b >>= 1, a = (int)(1LL * a * a % MO))
        if (b & 1) ret = (int)(1LL * ret * a % MO);
    return ret;
}

// Left weight w_pos of a position: 2^(pos-2) for pos >= 2, and 1 for pos == 1
// (qmi maps the exponent -1 to 1).
int left_weight(int pos) { return qmi(2, pos - 2); }

// Diagnostic dump of a matrix to stdout; not called by the solution.
void debug(const Matrix &t)
{
    puts("------------------------------------");
    for (int i = 0; i < 3; i++, puts(""))
        for (int j = 0; j < 3; j++)
            printf("%d ", t.a[i][j]);
    puts("------------------------------------");
}

// Returns a * b mod MO. Entries are in [0, MO), so each product is below
// MO^2 (< 1e18) and the sum of three such products fits in a long long.
Matrix mult(const Matrix &a, const Matrix &b)
{
    Matrix ret;
    for (int i = 0; i < 3; i++)
        for (int j = 0; j < 3; j++) {
            long long s = 0;
            for (int k = 0; k < 3; k++)
                s += 1LL * a.a[i][k] * b.a[k][j];
            ret.a[i][j] = (int)(s % MO);
        }
    return ret;
}

// Node rt covers the concatenation of its children, left child first.
void pushup(int rt) { f[rt] = mult(f[rt << 1], f[rt << 1 | 1]); }

// Builds node rt over positions [l, r]; leaves are visited left to right, so
// values are read from stdin in position order.
void build(int rt, int l, int r)
{
    if (l == r) {
        f[rt].init(read(), left_weight(l));
        return;
    }
    int mid = (l + r) >> 1;
    build(rt << 1, l, mid);
    build(rt << 1 | 1, mid + 1, r);
    pushup(rt);
}

// Sets position c (within [l, r] of node rt) to value w and recomputes the
// products on the path back to rt.
void modify(int rt, int l, int r, int c, int w)
{
    if (l == r) {
        f[rt].init(w, left_weight(l));
        return;
    }
    int mid = (l + r) >> 1;
    if (c <= mid) modify(rt << 1, l, mid, c, w);
    else modify(rt << 1 | 1, mid + 1, r, c, w);
    pushup(rt);
}

// Prints (x_n + y_n) mod MO. Multiplying the start vector (0, 0, 1) by the
// root product just selects the root's last row, so read it directly.
void calc()
{
    const Matrix &root = f[1];
    printf("%d\n", (root.a[2][0] + root.a[2][1]) % MO);
}

int main()
{
    // Input comes from rseq.in and output goes to rseq.out. A failed freopen
    // closes the original stream, so there is nothing sensible to fall back to.
    if (!freopen("rseq.in", "r", stdin) || !freopen("rseq.out", "w", stdout))
        return 1;
    n = read(), m = read();
    // build() never terminates for n < 1, and f[] only has room for n <= maxn.
    if (n < 1 || n > maxn) return 1;
    build(1, 1, n);
    calc();
    for (int i = 1; i <= m; i++) {
        int pos = read(), val = read();
        modify(1, 1, n, pos, val);
        calc();
    }
    return 0;
}