线段树维护矩阵

对于一些只有单点修改且维护的信息具有递推性的题目,由于运算具有结合律,可以将两区间合并写成矩阵乘法的形式,省去一些麻烦的讨论。

前置知识:广义矩阵乘法

对于一个 n×m 的矩阵 A 和一个 m×t 的矩阵 B,定义广义矩阵乘法:

Ci,j=k=1mAi,kBk,j

满足分配律,即 (ab)c=(ac)(bc),则新定义的广义矩阵乘法具有结合律。

常见的例子有 =×,=+=+,=min/max 等。

广义矩阵乘法的单位矩阵主对角线上的元素是 的单位元,其他位置是 的单位元。

序列问题

例题:Luogu P6864 [RC-03] 记忆

题意:

有一个括号串 S,一开始 S 中只包含一对括号(即初始的 S()),接下来有 n 个操作,操作分为三种:

  1. 在当前 S 的末尾加一对括号(即 S 变为 S());

  2. 在当前 S 的最外面加一对括号(即 S 变为 (S));

  3. 取消第 x 个操作,即去除第 x 个操作造成过的一切影响(例如,如果第 x 个操作也是取消操作,且取消了第 y 个操作,那么当前操作的实质就是恢复了第 y 个操作的作用效果)。

每次操作后,你需要输出 S 的能够括号匹配的非空子串(子串要求连续)个数。

一个括号串能够括号匹配,当且仅当其左右括号数量相等,且任意一个前缀中左括号数量不少于右括号数量。

1n2×105

如果没有撤销,只需要维护合法子串数 ans 和 合法后缀数 suf,则操作 1 就是 ansans+suf+1,sufsuf+1,操作 2 就是 ansans+1,suf1

现在有了撤销,考虑用矩阵描述操作。

操作 1:

[anssuf1][anssuf1]×[100110111]

操作 2:

[anssuf1][anssuf1]×[100000111]

这里的矩阵乘法就是传统的矩阵乘法。

S 初始为 (),所以初始答案矩阵显然是 [111],答案就是将初始矩阵按顺序乘上若干个修改矩阵后,第一个元素的值。

现在考虑撤销:撤销一个修改就是把这个修改矩阵改成单位矩阵;撤销对一个修改的撤销就是把这个矩阵再改回对应的修改矩阵,以此类推……

我们发现有了矩阵这个神奇工具后撤销就变得平凡了,由于矩阵乘法具有结合律,所以可以用线段树维护,时间复杂度 O(nlogn)

#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define int ll
#define gmin(x, y) (x = min(x, y))
#define gmax(x, y) (x = max(x, y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double f128;
typedef pair<int, int> pii;
constexpr int N = 2e5 + 5;
struct matrix {
    int a[3][3];
    matrix() { memset(a, 0, sizeof a); }
    matrix(vector<int> v) {
        rep(i, 0, 2) rep(j, 0, 2) a[i][j] = v[i * 3 + j];
    }
    int *operator[](int x) { return a[x]; }
    matrix operator*(matrix &b) const {
        matrix c;
        rep(i, 0, 2) rep(k, 0, 2) rep(j, 0, 2)
            c[i][j] += a[i][k] * b[k][j];
        return c;
    }
};
const matrix uni({1, 0, 0, 0, 1, 0, 0, 0, 1}), 
             op1({1, 0, 0, 1, 1, 0, 1, 1, 1}),
             op2({1, 0, 0, 0, 0, 0, 1, 1, 1});
matrix nd[N * 4];
#define ls (p << 1)
#define rs (p << 1 | 1)
void build(int p, int l, int r) {
    if(l == r) return nd[p] = uni, void();
    int mid = (l + r) / 2;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    nd[p] = nd[ls] * nd[rs];
}
void modify(int p, int l, int r, int loc, const matrix &op) {
    if(l == r) return nd[p] = op, void();
    int mid = (l + r) / 2;
    if(loc <= mid) modify(ls, l, mid, loc, op);
    else modify(rs, mid + 1, r, loc, op);
    nd[p] = nd[ls] * nd[rs];
}
int n, p[N], op[N];
signed main() {
#ifdef ONLINE_JUDGE
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
#endif
    cin >> n;
    build(1, 1, n);
    rep(i, 1, n) {
        cin >> op[i];
        if(op[i] == 1) modify(1, 1, n, i, op1), p[i] = i;
        else if(op[i] == 2) modify(1, 1, n, i, op2), p[i] = i;
        else {
            int x; cin >> x;
            if(p[x] > 0) modify(1, 1, n, p[x], uni), p[i] = -p[x];
            else if(op[-p[x]] == 1) modify(1, 1, n, -p[x], op1), p[i] = -p[x];
            else modify(1, 1, n, -p[x], op2), p[i] = -p[x];
        }
        cout << nd[1][0][0] + nd[1][1][0] + nd[1][2][0] << endl;
    }
    return 0;
}

树上问题

例题:Luogu P4719 【模板】"动态 DP"&动态树分治

题意:

给定一棵 n 个点的树,点带点权。

m 次操作,每次操作给定 x,y,表示修改点 x 的权值为 y

你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

1n,m105,任意时刻点权 [100,100]

首先还是考虑没有修改的情况,也就是没有上司的舞会。取 1 号点为根,设 f(u,0) 表示只考虑 u 的子树时,不选择节点 u 的最大权值,f(u,1) 表示选择 u 的最大权值,则有:(其中 son(u) 表示 u 的所有儿子的集合,auu 的权值)

f(u,0)=vson(u)max(f(v,0),f(v,1))f(u,1)=au+vson(u)f(v,0)

答案就是 max(f(1,0),f(1,1))

现在带上修改,如果每次修改完都暴力重新 DP,时间复杂度 O(nq),TLE します。

我们发现,每次修改只会更改一条路径,这给了我们用树剖的自信。考虑结合上一题的思路,用树链剖分把树上问题转化为序列问题,然后用矩阵来描述转移,套上线段树求解。

然而我们尴尬地发现这个转移带一个 ,这导致我们无法用一个较小的矩阵直接描述转移。考虑利用树链剖分分出的“轻重儿子”概念,定义 g(u,0)g(u,1) 分别表示只考虑节点 u 和它的所有轻子树时,不选与选节点 u 的最大权值。那么就有:(这里的 vu 的重儿子)

g(u,0)=json(u)jvmax(f(j,0),f(j,1))g(u,1)=au+json(u)jvf(j,0)f(u,0)=g(u,0)+max(f(v,0),f(v,1))f(u,1)=g(u,1)+f(v,0)

于是我们消掉了 f 的转移上的 ,但是这仍然不能用传统矩阵乘法,于是我们定义一种广义的矩阵乘法,令 =max,=+,即 C=A×B 当且仅当 Ci,j=maxk=1m{Ai,k+Bk,j}m 表示矩阵大小)。注意到这个乘法满足结合律是因为加法对 max 满足分配率,即 max(a,b)+c=max(a+c,b+c),我们可以据此改写转移方程:

f(u,0)=max(f(v,0)+g(u,0),f(v,1)+g(u,0))f(u,1)=max(g(u,1)+f(v,0),)

于是我们可以构造转移矩阵:(这里用的是我们新定义的广义矩阵乘法)(状态矩阵的横竖会影响实现,事实上横竖都能做)

[f(u,0)f(u,1)]=[g(u,0)g(u,0)g(u,1)]×[f(v,0)f(v,1)]

我们在每个节点维护转移矩阵。接下来考虑修改,把节点 u 的权值改为 w

u 所在的这条重链上,因为其他节点的轻儿子都不包含 u,所以被更改的只有 g(u,1)。更改完这条链后,如果已经到达根(top(u)=1)就退出,否则根据 f(top(u)) 更改 g(fa(top(u))),向上递归。

如果要快速获得一个点 u 的 DP 值,只需要找到它所在的重链的底部节点 x,查询 ux 的路径上的总转移矩阵,与 v 点的状态矩阵相乘即可。因为 x 一定是叶子,所以它的的状态一定是 [0ax]

由于叶节点的转移没有意义,我们可以直接让叶节点的转移是 [0voidaxvoid],其中 void 没有意义,可以是任何值,这样就不用手动乘 [0ax] 了。

时间复杂度 O(nlog2n),用一个叫全局平衡二叉树的东西可以优化到 O(nlogn),但是我不会。

实现时要注意,更改时不能用全局定义的 f 数组,因为那个 f 不是实时更新的。

#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define gmin(x, y) (x = min(x, y))
#define gmax(x, y) (x = max(x, y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double f128;
typedef pair<int, int> pii;
constexpr int N = 1e5 + 5, inf = 0x3f3f3f3f;
struct matrix {
    int a[2][2];
    int* operator[](int x) { return a[x]; }
    matrix operator*(matrix b) const {
        matrix c;
        c[0][0] = max(a[0][0] + b[0][0], a[0][1] + b[1][0]);
        c[0][1] = max(a[0][0] + b[0][1], a[0][1] + b[1][1]);
        c[1][0] = max(a[1][0] + b[0][0], a[1][1] + b[1][0]);
        c[1][1] = max(a[1][0] + b[0][1], a[1][1] + b[1][1]);
        return c;
    }
};
int n, m, a[N], f[N][2], g[N][2];
vector<int> to[N];
int fa[N], dep[N], sz[N], hs[N], tp[N], ed[N], dfn[N], rnk[N];
void dfs1(int u, int f) {
    fa[u] = f, dep[u] = dep[f] + 1, sz[u] = 1;
    for(int v : to[u]) if(v != f) {
        dfs1(v, u);
        sz[u] += sz[v];
        if(sz[v] > sz[hs[u]]) hs[u] = v;
    }
}
int dfs2(int u, int fa, int top) {
    static int ind;
    tp[u] = top, dfn[u] = ++ind, rnk[ind] = u;
    g[u][1] = a[u];
    if(hs[u]) ed[u] = dfs2(hs[u], u, top);
    else return f[u][1] = a[u], ed[u] = dfn[u];

    for(int v : to[u]) if(v != fa && v != hs[u]) {
        dfs2(v, u, v);
        g[u][0] += max(f[v][0], f[v][1]);
        g[u][1] += f[v][0];
    }
    f[u][0] = g[u][0] + max(f[hs[u]][0], f[hs[u]][1]);
    f[u][1] = g[u][1] + f[hs[u]][0];
    return ed[u];
}
matrix nd[N * 4];
#define ls (p << 1)
#define rs (p << 1 | 1)
void assign(int p, int u) {
    if(!hs[u]) { 
        nd[p][1][0] = a[u];
    } 
    else { 
        nd[p][0][0] = nd[p][0][1] = g[u][0]; 
        nd[p][1][0] = g[u][1]; 
        nd[p][1][1] = -inf; 
    }
}
void build(int p, int l, int r) {
    if(l == r) return assign(p, rnk[l]);
    int mid = (l + r) / 2;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    nd[p] = nd[ls] * nd[rs];
}
void update(int p, int l, int r, int loc) {
    if(l == r) return assign(p, rnk[l]);
    int mid = (l + r) / 2;
    if(loc <= mid) update(ls, l, mid, loc);
    else update(rs, mid + 1, r, loc);
    nd[p] = nd[ls] * nd[rs];
}
matrix query(int p, int l, int r, int ql, int qr) {
    if(ql <= l && r <= qr) return nd[p];
    int mid = (l + r) / 2;
    if(qr <= mid) return query(ls, l, mid, ql, qr);
    if(ql > mid) return query(rs, mid + 1, r, ql, qr);
    return query(ls, l, mid, ql, qr) * query(rs, mid + 1, r, ql, qr);
}
void modify(int u, int w) {
    g[u][1] += w - a[u];
    a[u] = w;
    matrix o, p;
    while(1) {
        if(tp[u] != 1) o = query(1, 1, n, dfn[tp[u]], ed[u]);
        update(1, 1, n, dfn[u]);
        if(tp[u] == 1) break;
        p = query(1, 1, n, dfn[tp[u]], ed[u]);
        u = fa[tp[u]];
        g[u][0] += max(p[0][0], p[1][0]) - max(o[0][0], o[1][0]);
        g[u][1] += p[0][0] - o[0][0];
    }
}
signed main() {
#ifdef ONLINE_JUDGE
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
#endif
    cin >> n >> m;
    rep(i, 1, n) cin >> a[i];
    rep(i, 1, n - 1) {
        int u, v; cin >> u >> v;
        to[u].push_back(v);
        to[v].push_back(u);
    }
    dfs1(1, 0);
    dfs2(1, 0, 1);
    build(1, 1, n);
    while(m--) {
        int u, w; cin >> u >> w;
        modify(u, w);
        auto o = query(1, 1, n, 1, ed[1]);
        cout << max(o[0][0], o[1][0]) << endl;
    }
    return 0;
}

练习题:Luogu P8820 [CSP-S 2022] 数据传输

形式化题意:

给定一棵 n 个节点的树,和一个常数 k,树上每个节点 i 都有权值 vi

定义 dis(i,j) 表示树上 i,j 两点间简单路径的边数。

Q 次询问,每次询问给定两个节点 s,t(st),你需要找出一个长为 m 的序列 c,满足 i[1,m],ci[1,n]c1=s,cm=ti[1,m1],dis(ci,ci+1)k,使得 i=1mvci 最小,输出这个最小值。

1n,Q2×105,1vi109,1k3

其实就是从 s 开始,每次最多走 k 步,走到 t,最小化经过的的点权和。

注意到 k 很小,设 path(i,j) 表示 i,j 间简单路径上的点的集合,考虑分类讨论 k

  • k=1 时,就是简单路径点权和,容易实现
  • k=2 时,如果存在 i[1,m],cipath(s,t),则必然存在j[i+1,m],cjpath(s,t),因为原图是树且 k=2,所以 dis(ci,cj)2(这里可以画个图理解一下)所以一定可以直接从 ci1 走到 cj,又因为 v>0,所以不走 cicj1 一定比走这些点要优,所以 i[1,m],cipath(s,t),于是可以简单 DP f(i)=max(f(i1),f(i2))+vci,令广义矩阵乘法 =min,=+,容易构造转移矩阵

[f(i)f(i1)]=[vcivci0]×[f(i1)f(i2)]

  • k=3 时,沿用 k=2 时的证明不难发现,如果存在 cipath(s,t),则必然有 minjpath(s,t)dis(ci,j)=1,这种情况也是容易考虑的,设 f(i,j) 表示到达与 ci 的距离为 j 的点所需要的最小代价,设 mnu 表示所有与 u 相邻的点中最小的权值,可以构造转移矩阵:

[f(i,0)f(i,1)f(i,2)]=[vcivcivci0mnci0]×[f(i1,0)f(i1,1)f(i1,2)]

然后树剖线段树维护就可以了,不过由于这个东西是有顺序的,所以线段树节点上需要同时维护左乘右和右乘左两个矩阵。

时间复杂度 O(nlog2n),实现时需要注意树剖跳重链时矩阵乘法的顺序。左乘右和右乘左两种询问分开写可以减小常数。

#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define gmin(x, y) (x = min(x, y))
#define gmax(x, y) (x = max(x, y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double f128;
typedef pair<int, int> pii;
constexpr int N = 2e5 + 5;
int n, q, k, a[N], mn[N];
vector<int> to[N];
struct matrix {
    ll a[3][3];
    matrix() { memset(a, 0x3f, sizeof a); }
    ll* operator[](int x) { return a[x]; }
    matrix operator*(matrix b) const {
        matrix c;
        rep(i, 0, 2) rep(j, 0, 2)
            c[i][j] = min({a[i][0] + b[0][j], a[i][1] + b[1][j], a[i][2] + b[2][j]});
        return c;
    }
};
int fa[N], dep[N], sz[N], hs[N], top[N], dfn[N], rnk[N];
void dfs1(int u, int f) {
    dep[u] = dep[f] + 1;
    fa[u] = f;
    sz[u] = 1;
    for(int v : to[u]) if(v != f) {
        dfs1(v, u);
        sz[u] += sz[v];
        if(sz[v] > sz[hs[u]]) hs[u] = v;
    }
}
void dfs2(int u, int f, int tp) {
    static int ind;
    top[u] = tp;
    dfn[u] = ++ind;
    rnk[ind] = u;
    if(hs[u]) dfs2(hs[u], u, tp);
    for(int v : to[u]) 
        if(v != f && v != hs[u]) 
            dfs2(v, u, v);
}
pair<matrix, matrix> nd[N * 4];
#define ls (p << 1)
#define rs (p << 1 | 1)
void assign(int p, int u) {
    nd[p].F[1][0] = nd[p].F[2][1] = nd[p].S[1][0] = nd[p].S[2][1] = 0;
    nd[p].F[0][0] = nd[p].S[0][0] = a[u];
    if(k >= 2) nd[p].F[0][1] = nd[p].S[0][1] = a[u];
    if(k == 3) 
        nd[p].F[0][2] = nd[p].S[0][2] = a[u], 
        nd[p].F[1][1] = nd[p].S[1][1] = mn[u];
}
void build(int p, int l, int r) {
    if(l == r) return assign(p, rnk[l]);
    int mid = (l + r) / 2;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    nd[p].F = nd[ls].F * nd[rs].F;
    nd[p].S = nd[rs].S * nd[ls].S;
}
matrix qlr(int p, int l, int r, int ql, int qr) {
    if(ql <= l && r <= qr) return nd[p].F;
    int mid = (l + r) / 2;
    if(qr <= mid) return qlr(ls, l, mid, ql, qr);
    if(ql > mid) return qlr(rs, mid + 1, r, ql, qr);
    return qlr(ls, l, mid, ql, qr) * qlr(rs, mid + 1, r, ql, qr);
}
matrix qrl(int p, int l, int r, int ql, int qr) {
    if(ql <= l && r <= qr) return nd[p].S;
    int mid = (l + r) / 2;
    if(qr <= mid) return qrl(ls, l, mid, ql, qr);
    if(ql > mid) return qrl(rs, mid + 1, r, ql, qr);
    return qrl(rs, mid + 1, r, ql, qr) * qrl(ls, l, mid, ql, qr);
}
ll query(int u, int v) {
    matrix um, vm;
    rep(i, 0, 2) um[i][i] = vm[i][i] = 0;
    while(top[u] != top[v]) {
        if(dep[top[u]] >= dep[top[v]]) {
            um = qlr(1, 1, n, dfn[top[u]], dfn[u]) * um;
            u = fa[top[u]];
        }
        else {
            vm = vm * qrl(1, 1, n, dfn[top[v]], dfn[v]);
            v = fa[top[v]];
        }
    }
    matrix o;
    if(dep[u] >= dep[v]) o = vm * qlr(1, 1, n, dfn[v], dfn[u]) * um;
    else o = vm * qrl(1, 1, n, dfn[u], dfn[v]) * um;
    return o[0][k - 1];
}
signed main() {
#ifdef ONLINE_JUDGE
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
#endif
    cin >> n >> q >> k;
    memset(mn, 0x3f, sizeof mn);
    rep(i, 1, n) cin >> a[i];
    rep(i, 1, n - 1) {
        int u, v; cin >> u >> v;
        to[u].push_back(v);
        to[v].push_back(u);
        gmin(mn[u], a[v]);
        gmin(mn[v], a[u]);
    }
    dfs1(1, 0);
    dfs2(1, 0, 1);
    build(1, 1, n);
    while(q--) {
        int s, t; cin >> s >> t;
        cout << query(s, t) << endl;
    }
    return 0;
}
posted @   untitled0  阅读(293)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?
点击右上角即可分享
微信分享提示