【数据结构】线段树(字符串哈希)

// Segment tree over a string supporting point assignment and substring-hash
// queries. Three polynomial hashes, one per (bse[t], mod[t]) pair, are kept
// side by side; Query(..., t) returns component t, and two substrings should
// be treated as equal only if all three components match.
//
// Node o covering positions [l, r] stores, for each component t,
//     st[o][t] = sum_{i = l..r} code(s[i]) * bse[t]^(r - i)   (mod mod[t]),
// where code() is the character's unsigned byte value.
//
// Uses names defined elsewhere in the file: `ll` (presumably long long), the
// source string `s` (indexed by the same positions passed to Build), and
// `min`. The root must be node 1 (children of o are 2o and 2o + 1).
// Call Init(n) before Build.
namespace SegmentTree {

#define ls (o << 1)
#define rs (o << 1 | 1)

    // Hash bases and moduli; component t uses (bse[t], mod[t]).
    // NOTE: bse[0] = 13 is smaller than the spread of character codes, so
    // component 0 alone collides even on short strings, e.g. "ba" and "an":
    // 98 * 13 + 97 == 97 * 13 + 110 == 1371. Compare all three components.
    ll bse[3] = {13, 1331, 2333};
    ll mod[3] = {998244353, 1000000007, 1000000009};
    // bsk[i][t] = bse[t]^i mod mod[t], valid for 0 <= i <= n after Init(n).
    ll bsk[200005][3];
    // st[o][t]: hash component t of the segment covered by node o.
    ll st[200005 << 2][3];

    // Precomputes the base powers bsk[0..n]. Must run before Build, which
    // reads bsk through PushUp. n must be at most 200004 (size of bsk).
    void Init(int n) {
        for(int t = 0; t < 3; ++t) {
            bsk[0][t] = 1;
            for(int i = 1; i <= n; ++i)
                bsk[i][t] = (bsk[i - 1][t] * bse[t]) % mod[t];
        }
    }

    // Recomputes node o (covering [l, r]) from its children:
    // hash(left) * bse^len(right) + hash(right), where len(right) = r - m.
    // Both factors are below mod[t] (~1e9), so the product fits in a 64-bit ll.
    void PushUp(int o, int l, int r) {
        int m = (l + r) >> 1;
        for(int t = 0; t < 3; ++t)
            st[o][t] = (st[ls][t] * bsk[r - m][t] + st[rs][t]) % mod[t];
    }

    // Builds node o over s[l..r]; every leaf stores the character's unsigned
    // byte value in all three components.
    void Build(int o, int l, int r) {
        if(l == r) {
            // The unsigned char cast keeps leaves non-negative where plain
            // char is signed. With a negative leaf, `%` yields residues whose
            // sign depends on the tree decomposition, so two equal substrings
            // could get hashes differing by mod[t].
            for(int t = 0; t < 3; ++t)
                st[o][t] = static_cast<unsigned char>(s[l]);
            return;
        }
        int m = (l + r) >> 1;
        Build(ls, l, m);
        Build(rs, m + 1, r);
        PushUp(o, l, r);
    }

    // Sets position p to character v inside the tree (s itself is not
    // modified) and refreshes the hashes along the root-to-leaf path.
    // A position outside [l, r] is ignored rather than overwriting the
    // nearest boundary leaf.
    void Update(int o, int l, int r, int p, char v) {
        if(p < l || r < p)
            return;
        if(l == r) {
            for(int t = 0; t < 3; ++t)
                st[o][t] = static_cast<unsigned char>(v);
            return;
        }
        int m = (l + r) >> 1;
        if(p <= m)
            Update(ls, l, m, p, v);
        else
            Update(rs, m + 1, r, p, v);
        PushUp(o, l, r);
    }

    // Returns hash component t of the part of [ql, qr] that lies inside
    // [l, r]. An empty intersection (including ql > qr) returns 0, the hash
    // of the empty string, instead of recursing past the leaves.
    ll Query(int o, int l, int r, int ql, int qr, int t) {
        if(ql > qr || qr < l || r < ql)
            return 0;
        if(ql <= l && r <= qr)
            return st[o][t];
        int m = (l + r) >> 1;
        ll res = 0;
        if(ql <= m)
            res = Query(ls, l, m, ql, qr, t);
        // Append the right part [m + 1, min(r, qr)]: shift the left result
        // by that part's length, then add the right hash.
        if(qr >= m + 1)
            res = (res * bsk[min(r, qr) - m][t] + Query(rs, m + 1, r, ql, qr, t)) % mod[t];
        return res;
    }

#undef ls
#undef rs

}  // namespace SegmentTree

区间赋值

// String-hash segment tree with range assignment. Node o covering [l, r]
// stores, for each hash component t,
//     st[o][t] = sum_{i = l..r} code(s[i]) * bse[t]^(r - i)   (mod mod[t]),
// where code() is the character's unsigned byte value. Update assigns one
// digit character to a whole range using lazy tags; Query(..., t) returns
// hash component t of a range. Compare all three components for equality.
//
// Uses names defined elsewhere in the file: `ll` (presumably long long), the
// source string `s`, and `min`. The root must be node 1 (children of o are
// 2o and 2o + 1). Call Init(n) before Build.
struct SegmentTree {

#define ls (o << 1)
#define rs (o << 1 | 1)

    // Hash bases and moduli; component t uses (bse[t], mod[t]).
    // NOTE: bse[0] = 13 is smaller than the spread of character codes, so
    // component 0 alone collides on short strings; compare all three.
    ll bse[3] = {13, 1331, 2333};
    ll mod[3] = {998244353, 1000000007, 1000000009};
    // bsk[i][t] = bse[t]^i mod mod[t].
    ll bsk[100005][3];
    // hslen[i][d][t] = hash component t of i copies of the character '0' + d
    // (0 for i == 0). Only digits are tabulated, which is why Update accepts
    // only '0'..'9'.
    ll hslen[100005][10][3];
    // st[o][t]: hash component t of the segment covered by node o.
    ll st[100005 << 2][3];
    // lz[o]: digit character pending assignment to both children of o,
    // or 0 when nothing is pending.
    char lz[100005 << 2];

    // Precomputes base powers bsk[0..n] and repeated-digit hashes
    // hslen[0..n]. Must run before Build; n must be at most 100004.
    void Init(int n) {
        for(int t = 0; t < 3; ++t) {
            bsk[0][t] = 1;
            for(int i = 1; i <= n; ++i)
                bsk[i][t] = (bsk[i - 1][t] * bse[t]) % mod[t];
        }
        for(char c = '0'; c <= '9'; ++c) {
            for(int t = 0; t < 3; ++t) {
                hslen[0][c - '0'][t] = 0;
                for(int i = 1; i <= n; ++i)
                    hslen[i][c - '0'][t] = (hslen[i - 1][c - '0'][t] * bse[t] + c) % mod[t];
            }
        }
    }

    // Recomputes node o (covering [l, r]) from its children:
    // hash(left) * bse^len(right) + hash(right), where len(right) = r - m.
    void PushUp(int o, int l, int r) {
        int m = (l + r) >> 1;
        for(int t = 0; t < 3; ++t)
            st[o][t] = (st[ls][t] * bsk[r - m][t] + st[rs][t]) % mod[t];
    }

    // Moves a pending assignment at o to its children: each child's hash
    // becomes that of (child length) copies of the digit, and the tag moves
    // down one level.
    void PushDown(int o, int l, int r) {
        if(lz[o]) {
            int m = (l + r) >> 1;
            for(int t = 0; t < 3; ++t) {
                st[ls][t] = hslen[m - l + 1][lz[o] - '0'][t];
                st[rs][t] = hslen[r - m][lz[o] - '0'][t];
            }
            lz[ls] = lz[o];
            lz[rs] = lz[o];
            lz[o] = 0;
        }
    }

    // Builds node o over s[l..r], clearing lazy tags along the way.
    void Build(int o, int l, int r) {
        lz[o] = 0;
        if(l == r) {
            // The unsigned char cast keeps leaves non-negative where plain
            // char is signed. With a negative leaf, `%` yields residues whose
            // sign depends on the tree decomposition, so two equal substrings
            // could get hashes differing by mod[t].
            for(int t = 0; t < 3; ++t)
                st[o][t] = static_cast<unsigned char>(s[l]);
            return;
        }
        int m = (l + r) >> 1;
        Build(ls, l, m);
        Build(rs, m + 1, r);
        PushUp(o, l, r);
    }

    // Assigns digit character v to every position of [ql, qr] within [l, r].
    // v must be '0'..'9'; other values index hslen out of bounds. An empty
    // intersection (including ql > qr) is a no-op instead of pushing down
    // from a leaf and recursing past it.
    void Update(int o, int l, int r, int ql, int qr, char v) {
        if(ql > qr || qr < l || r < ql)
            return;
        if(ql <= l && r <= qr) {
            lz[o] = v;
            for(int t = 0; t < 3; ++t)
                st[o][t] = hslen[r - l + 1][v - '0'][t];
            return;
        }
        PushDown(o, l, r);
        int m = (l + r) >> 1;
        if(ql <= m)
            Update(ls, l, m, ql, qr, v);
        if(qr >= m + 1)
            Update(rs, m + 1, r, ql, qr, v);
        PushUp(o, l, r);
    }

    // Returns hash component t of the part of [ql, qr] that lies inside
    // [l, r]. An empty intersection (including ql > qr) returns 0, the hash
    // of the empty string, instead of recursing past the leaves.
    ll Query(int o, int l, int r, int ql, int qr, int t) {
        if(ql > qr || qr < l || r < ql)
            return 0;
        if(ql <= l && r <= qr)
            return st[o][t];
        PushDown(o, l, r);
        int m = (l + r) >> 1;
        ll res = 0;
        if(ql <= m)
            res = Query(ls, l, m, ql, qr, t);
        // Append the right part [m + 1, min(r, qr)]: shift the left result
        // by that part's length, then add the right hash.
        if(qr >= m + 1)
            res = (res * bsk[min(r, qr) - m][t] + Query(rs, m + 1, r, ql, qr, t)) % mod[t];
        return res;
    }

#undef ls
#undef rs

} st;
posted @ 2020-11-26 22:34  purinliang  阅读(135)  评论(0)  编辑  收藏  举报