P7453 [THUSCH2017] 大魔法师 题解

Description

Luogu传送门

Solution

标签里就俩东西,没错就是他们了。

观察到 4,5,6 操作中需要用到一个常数 \(v\),所以我们的矩阵得开到 1 * 4,存 \((A, B, C, 1)\)

至于转移的 6 个矩阵这里就不推了,比较基础(不会真的有人来做这道题了连矩阵都不会推吧)。

简单说下线段树上矩阵乘法的含义。

线段树上的每个节点都维护一个矩阵。

叶子节点就是每个矩阵本身。

其他的点就是所表示区间内的所有矩阵的对应位的和。

  • 对于更新操作:

    跟普通的线段树一样,就是把一个数乘上一个数改成了一个矩阵乘上一个矩阵,\(lazy\) 标记什么的也都一样更新。

    注意:矩阵乘法没有交换律,可别瞎乘了。

  • 对于询问:

    也跟普通的线段树是一样的。真的一点区别都没有。

但是本蒟蒻在一开始写的时候觉得有一点点无从下手,这里记录一下我的思考历程。

Q:这 6 个操作怎么用矩阵乘法维护啊?

A:这个简单,随便推推矩阵就好了,矩阵中是应该是 v 的地方先赋成 0,需要用的时候再给他赋上 v

Q:那这 6 个矩阵我怎么存下来呢,执行操作的时候我该怎么调用数组呢?

A:不要怕麻烦!直接开 6 个 4 * 4 的数组就完了!调用的时候判断该用哪个用哪个。

Q:我线段树 build 的时候初值怎么赋呢?

A:初值就是输入的东西,对于每一个叶子节点把输入的数直接扔进去就好了。

Q:我需要用到线段树的哪些操作呢?

A:在线段树上矩阵乘法一般都只需要区间乘法,这题也不例外,所以我们只需要维护一个乘法的懒惰标记。在 \(update\) 函数中,直接传一个矩阵变量即可。

再讲讲卡常:

  1. 最为重要: 矩阵乘法的最内层循环使用循环展开,加法取模改成减法。
  2. 使用快读。
  3. 不要开 \(long \ long\),乘法的时候强制类型转换(不会真的有人全开 \(long \ long\) 了吧)。
  4. 如果你不怕麻烦的话,不要使用重载的乘法运算符,因为那样是 4 * 4 的两个矩阵相乘,而你只需要 1 * 4 \(\times\) 4 * 4

其实只要有第一条和第三条基本上就能卡过去了。

下面就是你们最想看到的代码了 \(:)\)

Code

#include <bits/stdc++.h>

using namespace std;

namespace IO{
    inline int read(){
        int x = 0;
        char ch = getchar();
        while(!isdigit(ch)) ch = getchar();
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x;
    }

    template <typename T> inline void write(T x){
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;

const int N = 3e5 + 10;
const int mod = 998244353;
int n, m;

inline int add(int x) {return x > mod ? x - mod : x;}

struct matrix{
    int num[4][4];

    matrix() {memset(num, 0, sizeof(num));}
    inline void clear() {memset(num, 0, sizeof(num));}
    inline void init() {num[0][0] = num[1][1] = num[2][2] = num[3][3] = 1;}

    matrix(const int a[][4]){
        for(int i = 0; i < 4; ++i)
            for(int j = 0; j < 4; ++j)
                num[i][j] = a[i][j];
    }

    matrix operator + (const matrix &b) const{
        matrix r;
        for(int i = 0; i < 4; ++i)
            for(int j = 0; j < 4; ++j)
                r.num[i][j] = (1ll * num[i][j] + b.num[i][j]) % mod;
        return r;
    }

    matrix operator * (const matrix &b) const{
        matrix r;
        for(int i = 0; i < 4; ++i)
            for(int j = 0; j < 4; ++j){
                    r.num[i][j] = add(r.num[i][j] + (1ll * num[i][0] * b.num[0][j]) % mod);
                    r.num[i][j] = add(r.num[i][j] + (1ll * num[i][1] * b.num[1][j]) % mod);
                    r.num[i][j] = add(r.num[i][j] + (1ll * num[i][2] * b.num[2][j]) % mod);
                    r.num[i][j] = add(r.num[i][j] + (1ll * num[i][3] * b.num[3][j]) % mod);
                }
        return r;
    }
}a[N], A[4], B[4];

#define ls rt << 1
#define rs rt << 1 | 1

matrix sum[N << 2], lazy[N << 2];

inline void pushup(int rt){
    sum[rt] = sum[ls] + sum[rs];
}

inline void pushdown(int rt){
    sum[ls] = sum[ls] * lazy[rt];
    sum[rs] = sum[rs] * lazy[rt];
    lazy[ls] = lazy[ls] * lazy[rt];
    lazy[rs] = lazy[rs] * lazy[rt];
    lazy[rt].clear(), lazy[rt].init();
}

inline void build(int l, int r, int rt){
    lazy[rt].init();
    if(l == r){
        sum[rt] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, ls);
    build(mid + 1, r, rs);
    pushup(rt);
}

inline void update(int L, int R, matrix k, int l, int r, int rt){
    if(L <= l && r <= R){
        sum[rt] = sum[rt] * k;
        lazy[rt] = lazy[rt] * k;
        return;
    }
    pushdown(rt);
    int mid = (l + r) >> 1;
    if(L <= mid) update(L, R, k, l, mid, ls);
    if(R > mid) update(L, R, k, mid + 1, r, rs);
    pushup(rt);
}

inline matrix query(int L, int R, int l, int r, int rt){
    if(L <= l && r <= R) return sum[rt];
    pushdown(rt);
    int mid = (l + r) >> 1;
    matrix res;
    if(L <= mid) res = res + query(L, R, l, mid, ls);
    if(R > mid) res = res + query(L, R, mid + 1, r, rs);
    return res;
}

int main(){
    A[0].init(), A[1].init(), A[2].init(), B[0].init(), B[1].init(), B[2].init();
    A[0].num[1][0] = A[1].num[2][1] = A[2].num[0][2] = 1;
    B[2].num[2][2] = 0;
    n = read();
    for(int i = 1; i <= n; ++i)
        a[i].num[0][0] = read(), a[i].num[0][1] = read(), a[i].num[0][2] = read(), a[i].num[0][3] = 1;
    build(1, n, 1);
    m = read();
    while(m--){
        int op = read(), l = read(), r = read();
        if(op <= 3) update(l, r, A[op - 1], 1, n, 1);
        else if(op <= 6){
            B[0].num[3][0] = B[1].num[1][1] = B[2].num[3][2] = read();
            update(l, r, B[op - 4], 1, n, 1);
        }else{
            matrix ans = query(l, r, 1, n, 1);
            printf("%d %d %d\n", ans.num[0][0], ans.num[0][1], ans.num[0][2]);
        }
    }
    return 0;
}

\[\_EOF\_ \]

posted @ 2021-12-20 19:20  xixike  阅读(156)  评论(6编辑  收藏  举报