[BZOJ4785][P3688][ZJOI2017]树状数组[树套树+树状数组]
这份代码在某些corner case会出错
因为是\(\mod 2\)意义下进行的运算,所以可以用异或来代替(只是为了方便。。
\[\text{query}\left( l,r \right) =\text{find}\left( r \right) -\text{find}\left( l-1 \right) =\text{xor}\left( r,n \right) \ \text{xor\ xor}\left( l-\text{1,}n \right) =\text{xor}\left( l-\text{1,}r-1 \right)
\]
实际要求的是$\text{xor}\left( l,r \right) $
所以 答案错误当且仅当 \(\left( l-1 \right) \oplus r\) 为 \(1\)
所以每次询问就是询问 \(\left( l-1 \right) \oplus r\) 为 \(0\)的概率
对于一个add操作\([l_1,r_1]\),只会对以下两种区间产生影响
1.左右端点有一个在\([l_1,r_1]\)内,会有\(\frac{1}{r-l+1}\)的概率被取反
2.左右端点均在\([l_1,r_1]\)内,会有\(\frac{2}{r-l+1}\)的概率被取反
这样的话就涉及两维的修改,所以需要用二维线段树来做,线段树上的节点\((a_1,a_2)\)表示左端点是\(a_1\)右端点\(a_2\)时不出错的概率
修改的时候原有概率是\(x\),要合并的概率是\(y\)
\(x\ *\ \left( 1-y \right) \ +\ y\ *\ \left( 1-x \right)\)
这样就可以把两个概率合并起来
另外l=1的话要特判
这个tot不太懂有什么用,去掉bzoj仍然能过,可能是corner case.....以及++n的作用。。。求指点
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
inline void read(int &x) {
int c = getchar();
x = 0;
while(!isdigit(c)) c = getchar();
while(isdigit(c)) x = x * 10 + c - '0', c = getchar();
}
const int MAXN = 400010, MAXM = 35000000, mod = 998244353;
template<typename T> T Pow(T a, T b) {
T ret = 1;
for( ; b; b >>= 1) {
if (b & 1) ret = ret * 1ull * a % mod;
a = a * 1ull * a % mod;
}
return ret%mod;
}
inline int merge(int a, int b) {
return ( 1ll * a * ( mod + 1 - b ) + 1ll * b * ( mod + 1 - a ) ) % mod;
}
int n, tot, Q, l, r, op, len, inv, v;
struct Node {
int val;
Node *ls, *rs;
} pool[MAXM];
struct node {
Node *rt;
node *ls, *rs;
} Pool[MAXN], *root;
inline Node *newNode() {
static int cnt = 0;
return &pool[cnt++];
}
inline node *newnode() {
static int cnt = 0;
return &Pool[cnt++];
}
node *build(int l, int r) {
node *cur = newnode();
if (l != r) {
int mid = l+r>>1;
cur->ls = build(l, mid);
cur->rs = build(mid+1, r);
}
return cur;
}
void add2(Node *&cur, int l, int r, int ql, int qr, int v) {
if (!cur) cur = newNode();
if (ql <= l && r <= qr) {
cur->val = merge(cur->val, v);
return ;
}
int mid = l+r>>1;
if (ql <= mid) add2(cur->ls, l, mid, ql, qr, v);
if (qr > mid) add2(cur->rs, mid+1, r, ql, qr, v);
}
void add1(node *cur, int l, int r, int ql, int qr, int nxtl, int nxtr, int v) {
if (ql <= l && r <= qr) return add2(cur->rt, 0, n, nxtl, nxtr, v);
int mid = l+r>>1;
if (ql <= mid) add1(cur->ls, l, mid, ql, qr, nxtl, nxtr, v);
if (qr > mid) add1(cur->rs, mid+1, r, ql, qr, nxtl, nxtr, v);
}
int query2(Node *cur, int l, int r, int p) {
if (!cur) return 0;
int ret = cur->val;
if (l == r) return ret;
int mid = l+r>>1;
if (p <= mid) return merge(query2(cur->ls, l, mid, p), ret);
return merge(query2(cur->rs, mid+1, r, p), ret);
}
int query1(node *cur, int l, int r, int p, int v) {
int ret = query2(cur->rt, 0, n, v);
if (l == r) return ret;
int mid = l+r>>1;
if (p <= mid) return merge(query1(cur->ls, l, mid, p, v), ret);
return merge(query1(cur->rs, mid+1, r, p, v), ret);
}
int main(void) {
read(n), read(Q), ++n;
root = build(0, n);
while(Q--) {
read(op), read(l), read(r), len = r-l+1, inv = Pow(len, mod-2);
if (op == 1) {
add1(root, 0, n, l, r, l, r, 1ll * 2 * inv % mod);
add1(root, 0, n, 0, l-1, l, r, inv);
add1(root, 0, n, l, r, r+1, n, inv);
tot ^= 1;
} else {
if (l == 1 && tot) printf("%d\n", query1(root, 0, n, 0, r));
else printf("%d\n", (1 + mod - query1(root, 0, n, l-1, r)) % mod);
}
}
return 0;
}