[BZOJ4785][P3688][ZJOI2017]树状数组[树套树+树状数组]

这份代码在某些corner case会出错

因为是\(\mod 2\)意义下进行的运算,所以可以用异或来代替(只是为了方便。。

\[\text{query}\left( l,r \right) =\text{find}\left( r \right) -\text{find}\left( l-1 \right) =\text{xor}\left( r,n \right) \ \text{xor\ xor}\left( l-\text{1,}n \right) =\text{xor}\left( l-\text{1,}r-1 \right) \]

实际要求的是$\text{xor}\left( l,r \right) $

所以 答案错误当且仅当 \(\left( l-1 \right) \oplus r\)\(1\)

所以每次询问就是询问 \(\left( l-1 \right) \oplus r\)\(0\)的概率

对于一个add操作\([l_1,r_1]\),只会对以下两种区间产生影响

1.左右端点有一个在\([l_1,r_1]\)内,会有\(\frac{1}{r-l+1}\)的概率被取反

2.左右端点均在\([l_1,r_1]\)内,会有\(\frac{2}{r-l+1}\)的概率被取反

这样的话就涉及两维的修改,所以需要用二维线段树来做,线段树上的节点\((a_1,a_2)\)表示左端点是\(a_1\)右端点\(a_2\)时不出错的概率

修改的时候原有概率是\(x\),要合并的概率是\(y\)
\(x\ *\ \left( 1-y \right) \ +\ y\ *\ \left( 1-x \right)\)

这样就可以把两个概率合并起来

另外l=1的话要特判

这个tot不太懂有什么用,去掉bzoj仍然能过,可能是corner case.....以及++n的作用。。。求指点

#include <bits/stdc++.h>
using namespace std;
 
typedef long long ll;
 
inline void read(int &x) {
    int c = getchar();
    x = 0;
    while(!isdigit(c)) c = getchar();
    while(isdigit(c)) x = x * 10 + c - '0', c = getchar();
}
const int MAXN = 400010, MAXM = 35000000, mod = 998244353;
template<typename T> T Pow(T a, T b) {
    T ret = 1;
    for( ; b; b >>= 1) {
        if (b & 1) ret = ret * 1ull * a % mod;
        a = a * 1ull * a % mod;
    }
    return ret%mod;
}
 
inline int merge(int a, int b) {
    return ( 1ll * a * ( mod + 1 - b ) + 1ll * b * ( mod + 1 - a ) ) % mod;
}
int n, tot, Q, l, r, op, len, inv, v;
 
struct Node {
    int val;
    Node *ls, *rs;
} pool[MAXM];
struct node {
    Node *rt;
    node *ls, *rs;
} Pool[MAXN], *root;
inline Node *newNode() {
    static int cnt = 0;
    return &pool[cnt++];
}
inline node *newnode() {
    static int cnt = 0;
    return &Pool[cnt++];
}
node *build(int l, int r) {
    node *cur = newnode();
    if (l != r) {
        int mid = l+r>>1;
        cur->ls = build(l, mid);
        cur->rs = build(mid+1, r);
    }
    return cur;
}
void add2(Node *&cur, int l, int r, int ql, int qr, int v) {
    if (!cur) cur = newNode();
    if (ql <= l && r <= qr) {
        cur->val = merge(cur->val, v);
        return ;
    }
    int mid = l+r>>1;
    if (ql <= mid) add2(cur->ls, l, mid, ql, qr, v);
    if (qr > mid) add2(cur->rs, mid+1, r, ql, qr, v);
}
 
void add1(node *cur, int l, int r, int ql, int qr, int nxtl, int nxtr, int v) {
    if (ql <= l && r <= qr) return add2(cur->rt, 0, n, nxtl, nxtr, v);
    int mid = l+r>>1;
    if (ql <= mid) add1(cur->ls, l, mid, ql, qr, nxtl, nxtr, v);
    if (qr > mid) add1(cur->rs, mid+1, r, ql, qr, nxtl, nxtr, v);
}
 
int query2(Node *cur, int l, int r, int p) {
    if (!cur) return 0;
    int ret = cur->val;
    if (l == r) return ret;
    int mid = l+r>>1;
    if (p <= mid) return merge(query2(cur->ls, l, mid, p), ret);
    return merge(query2(cur->rs, mid+1, r, p), ret);
}
 
int query1(node *cur, int l, int r, int p, int v) {
    int ret = query2(cur->rt, 0, n, v);
    if (l == r) return ret;
    int mid = l+r>>1; 
    if (p <= mid) return merge(query1(cur->ls, l, mid, p, v), ret);
    return merge(query1(cur->rs, mid+1, r, p, v), ret);
}
 
int main(void) {
    read(n), read(Q), ++n;
    root = build(0, n);
    while(Q--) {
        read(op), read(l), read(r), len = r-l+1, inv = Pow(len, mod-2);
        if (op == 1) {
            add1(root, 0, n, l, r, l, r, 1ll * 2 * inv % mod);
            add1(root, 0, n, 0, l-1, l, r, inv);
            add1(root, 0, n, l, r, r+1, n, inv);
            tot ^= 1;
        } else {
            if (l == 1 && tot) printf("%d\n", query1(root, 0, n, 0, r));
            else printf("%d\n", (1 + mod - query1(root, 0, n, l-1, r)) % mod);
        }
    }
    return 0;
}
posted @ 2018-12-28 16:07  QvvQ  阅读(212)  评论(0编辑  收藏  举报