ZJOI2019 线段树

ZJOI2019 线段树

题意:

题目传送门

题解:

来讲一个非常卡常的矩阵+线段树做法。首先转化一下题意,直接将\(2^m\)棵线段树建出来一定是不现实的,我们对于每一个节点,记录这个节点在所有线段树中带有标记的次数,这样所有节点的权值之和就是答案了。
接下来考虑如何维护这个答案,由于每一个节点及其祖先的带标记情况只有四种:

  1. 当前节点带标记,祖先也带标记
  2. 当前节点不带标记,祖先带标记
  3. 当前节点带标记,祖先不带标记
  4. 当前节点不带标记,祖先也不带标记

实际上\(1.3\)两种状况是可以合起来的,因为可以发现之后转移当中,\(1.3\)的转移是相同的,所以我们可以将他们合起来(主要是合并起来会被评测的老年机子卡成\(40\)……)。所以我们当前记录\(f[i][0/1/2]\)表示\(i\)这个节点三种状态(分别对应上面四种情况的\(4, 2, 1 + 3\))的方案数。接下来考虑对于一个修改操作,我们的转移方法。我们假设当前节点为\(u\),其父亲节点为\(v\),修改区间为\([l, r]\)

  1. \([L_u, R_u] \subseteq [l, r]\)时,那么这个点所有祖先的标记都会下传到这个点上,并且这个点也会被打上标记,那么转移就是:

\[\begin{cases} f'[u][2] += f[u][0] + f[u][1] + f[u][2] \end{cases} \]

  1. \([L_u, R_u] \subset [l, r]\ \ \&\& \ \ [L_v, R_v] \subseteq [l, r]\),即当前节点属于修改区间,但是是在其祖先节点打上修改标记的这些节点。那么这个点所有情况都会被改成祖先带标记的情况,转移就是:

\[\begin{cases} f'[u][1] += f[u][0] + f[u][1] \\ f'[u][2] += f[u][2] \end{cases} \]

  1. \([L_u, R_u] \cap [l, r] \neq \emptyset\),即这个节点不会被打上修改标记,但是在定位修改区间的时候会被访问到的节点。这些节点及其祖先所有的标记都会下传,所有的情况都会被改成祖先与自身都不带标记的情况,转移就是:

\[\begin{cases} f'[u][0] += f[u][0] + f[u][1] + f[u][2] \\ \end{cases} \]

  1. \([L_u, R_u] \cap [l, r] = \emptyset \ \ \& \& \ \ [L_v, R_v] \cap [l, r] \neq \emptyset\),即这个节点不会在定位区间时被访问到,但是其父亲会被访问到,那么它的祖先如果有标记,就会下传到这个节点中,转移就是:

\[\begin{cases} f'[u][0] += f[u][0] \\ f'[u][2] += f[u][1] + f[u][2] \end{cases} \]

  1. \([L_u, R_u] \cap [l, r] = \emptyset \ \ \& \& \ \ [L_v, R_v] \cap [l, r] = \emptyset\),即这些节点与本次修改无关,直接把原来的方案数乘2即可,转移就是:

\[\begin{cases} f'[u][0] += f[u][0] \\ f'[u][1] += f[u][1] \\ f'[u][2] += f[u][2] \end{cases} \]

直接暴力转移复杂度就是\(O(n^2)\)的,我们发现\(1.3.4\)这三种转移在定位修改区间时都会被访问到,所以可以直接进行修改,然后\(2.5\)两个转移我们考虑打标记。考虑用矩阵进行维护,那么第二种转移的转移矩阵就是这样的:

\[\left[ \begin{matrix} 1 & 1 & 0 \\ 0 & 2 & 0 \\ 0 & 0 & 2 \end{matrix} \right] \]

第三种转移的转移矩阵就是这样的:

\[\left[ \begin{matrix} 2 & 0 & 0 \\ 0 & 2 & 0 \\ 0 & 0 & 2 \end{matrix} \right] \]

然后就是线段树打标记,查询根节点权值就行了。说实话……这个方法似乎用省选评测的老年机似乎是跑不过去的……算了我这个考场上根本没有码出来的菜鸡就不说了吧……

UPD:话说不用矩阵维护的方法似乎好写好调跑的又快啊……啊我真是菜爆了……

Code:

#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 50;
const int Md = 998244353;
typedef long long ll;

inline int Add(const int &x, const int &y) { return (x + y >= Md) ? (x + y - Md) : (x + y); }
inline int Sub(const int &x, const int &y) { return (x - y < 0) ? (x - y + Md) : (x - y); }
inline int Mul(const int &x, const int &y) { return (ll)x * y % Md; }
int Powe(int x, int y) {
    int ans = 1;
    while (y) {
        if (y & 1)
            ans = Mul(ans, x);
        x = Mul(x, x);
        y >>= 1;
    }
    return ans;
}

int n, m;

namespace Solver2 {
#define ls(o) (o << 1)
#define rs(o) (o << 1 | 1)
    struct Mat {
        int v[3][3];
        Mat() { memset(v, 0, sizeof v); }
        int *operator[](int x) { return v[x]; }
        Mat operator*(Mat B) {
            Mat C;
    
            for (int k = 0; k < 3; k++) {
                for (int i = 0; i < 3; i++) {
                    if (!v[i][k])
                        continue;
    
                    if (B.v[k][0])
                        C.v[i][0] = Add(C.v[i][0], Mul(v[i][k], B.v[k][0]));
                    if (B.v[k][1])
                        C.v[i][1] = Add(C.v[i][1], Mul(v[i][k], B.v[k][1]));
                    if (B.v[k][2])
                        C.v[i][2] = Add(C.v[i][2], Mul(v[i][k], B.v[k][2]));
                }
            }
            return C;
        }
        void Base() {
            memset(v, 0, sizeof v);
            for (int i = 0; i < 3; i++) v[i][i] = 1;
        }
        int EquBase() {
            for (int i = 0; i < 3; i++) {
                for (int j = 0; j < 3; j++) {
                    if (i == j && v[i][j] != 1)
                        return 0;
                    if (i != j && v[i][j] != 0)
                        return 0;
                }
            }
            return 1;
        }
    };
    Mat Bas, trans1, trans2;
    
    Mat va[N << 2], tag[N << 2], sum[N << 2];
    
    void Apply(int o, Mat A) {
        sum[o] = sum[o] * A;
        tag[o] = tag[o] * A;
        va[o] = va[o] * A;
    }
    
    void Push(int o) {
        if (tag[o].EquBase()) return;
        Apply(ls(o), tag[o]);
        Apply(rs(o), tag[o]);
        tag[o].Base();
        return;
    }
    
    void Update(int o) {
        for (int i = 0; i < 3; i++) sum[o][0][i] = Add(sum[ls(o)][0][i], sum[rs(o)][0][i]), sum[o][0][i] = Add(sum[o][0][i], va[o][0][i]);
    }
    
    void Pre(int o, int l, int r) {
        va[o][0][0] = 1;
        tag[o].Base();
        sum[o][0][0] = 1;
        if (l == r) return;
        int mid = (l + r) >> 1;
        Pre(ls(o), l, mid);
        Pre(rs(o), mid + 1, r);
        Update(o);
    }
    
    void Modify(int o, int l, int r, int L, int R) {
        if (l > R || r < L) {
            va[o][0][0] = Add(va[o][0][0], va[o][0][0]);
            va[o][0][2] = Add(va[o][0][2], Add(va[o][0][1], va[o][0][2]));
            if (l == r) return (void)(Update(o));
            int mid = (l + r) >> 1;
            Push(o);
            Apply(ls(o), trans1);
            Apply(rs(o), trans1);
            Update(o);
            return;
        }
        if (L <= l && r <= R) {
            va[o][0][2] = Add(va[o][0][2], Add(va[o][0][2], Add(va[o][0][1], va[o][0][0])));
            if (l != r) {
                Push(o);
                va[ls(o)][0][1] = Add(va[ls(o)][0][1], Add(va[ls(o)][0][1], va[ls(o)][0][0]));
                va[ls(o)][0][2] = Add(va[ls(o)][0][2], va[ls(o)][0][2]);
                va[rs(o)][0][1] = Add(va[rs(o)][0][1], Add(va[rs(o)][0][1], va[rs(o)][0][0]));
                va[rs(o)][0][2] = Add(va[rs(o)][0][2], va[rs(o)][0][2]);
                int mid = (l + r) >> 1;
                if (l != mid) {
                    Push(ls(o));
                    Apply(ls(ls(o)), trans2);
                    Apply(rs(ls(o)), trans2);
                }
                if (r != mid) {
                    Push(rs(o));
                    Apply(ls(rs(o)), trans2);
                    Apply(rs(rs(o)), trans2);
                }
                Update(ls(o));
                Update(rs(o));
            }
            Update(o);
            return;
        }
        Push(o);
        int mid = (l + r) >> 1;
        va[o][0][0] = Add(va[o][0][0], Add(va[o][0][0], Add(va[o][0][1], va[o][0][2])));
        Modify(ls(o), l, mid, L, R);
        Modify(rs(o), mid + 1, r, L, R);
        Update(o);
        return;
    }
    
    void main() {
        Bas.Base();
        for (int i = 0; i < 3; i++) trans1[i][i] = 2;
        trans2[0][1] = trans2[0][0] = 1;
        trans2[1][1] = trans2[2][2] = 2;
        Pre(1, 1, n);
        for (int i = 1, tp; i <= m; i++) {
            scanf("%d", &tp);
            if (tp == 2)
                printf("%d\n", sum[1][0][2]);
            else {
                int l, r;
                scanf("%d%d", &l, &r);
                Modify(1, 1, n, l, r);
            }
        }
    }
}

int main() {
    scanf("%d%d", &n, &m);
    Solver2::main();
    return 0;
}
posted @ 2019-04-02 18:51  Apocrypha  阅读(382)  评论(0编辑  收藏  举报