【题解】 「NOI2017」整数 线段树+二分+压位 LOJ2302
Legend
请维护一个高精二进制数 \(s\),支持操作 \(n\ (0 \le n \le 10^6)\) 次:
- 加或减 \(a\times 2^{b}\)。\((|a|\le 10^9,0\le b\le 30n)\)。
- 查询 \(s \operatorname{and} 2^b\) 的结果转化为 \(\textrm{bool}\) 后是否为真。
时空 \(\textrm{2s/512MB}\)。
Editorial
作为 \(\textrm{NOI2017}\) 的第一题,必定是一道良心送温暖题,让我们一起为出题人松松松鼓掌。
brute
容易看到底下有一些部分分,映入眼帘的便是 \(|a|=1\),于是我们就想到把每一个加减操作看成 \(O(\log a)\) 次加减单个二进制位。怎么样?是不是看起来简单一点了?
考虑直接模拟。假设现在做加法的是位置 \(l\),如果这一位是 \(0\) 就直接改成 \(1\),否则即找到之后第一个为 \(0\) 的位置 \(p\ (l < p)\),把 \([l,p-1]\) 都改成 \(0\),并把位置 \(p\) 改成 \(1\)。
减法同理。如果这一位是 \(1\) 就直接改成 \(0\),否则即找到之后第一个为 \(1\) 的位置 \(p\ (l < p)\),把 \([l,p-1]\) 都改成 \(1\),并把位置 \(p\) 改成 \(0\)。
以上两个操作可以直接在线段树上二分找到,打区间覆盖标记。查询则可以直接使用线段树单点查询。
于是就得到了一个复杂度为 \(O(n \log n\log a)\) 的做法。
optimization
上述做法的瓶颈在于:
- 数组长度是 \(30n\),凭空多出来一个常数。
- 要进行拆位,\(1\) 个操作变成了 \(\log a\) 个。
不妨往反方向考虑,把数组压位,连续 \(32\) 个数字用一个 \(\textrm{unsigned int}\) 存储。
这样子对于一个修改操作我们最多只要拆成两个。而查询连续 \(1\) 段和连续 \(0\) 段依然可以用线段树实现,代码相差无几。
但这样就可以把复杂度优化到 \(O\left(\dfrac{n \log n \log a}{\omega}\right)\),其中 \(\omega\) 为压位大小。
Code
写的时候有点犯迷糊,最开始用 \(\textrm{unsigned int}\) 存了读入的 \(a\),后来又没写线段树的 \(\textrm{pushup pushdown}\),最后发现线段树二分写错了……白白浪费了一个下午+晚上。
就这样修修补补写出了下面这些东西,有点繁琐了,但还可以看。
LOJ 上这破烂可以在 \(\textrm{800ms}\) 内跑过。
// Author : Imakf
#include <bits/stdc++.h>
using namespace std;
#define LL long long
#define debug(...) fprintf(stderr ,__VA_ARGS__)
#define __FILE(x)\
freopen(#x".in" ,"r" ,stdin);\
freopen(#x".out" ,"w" ,stdout)
LL read(){
char k = getchar(); LL x = 0 ,flg = 1;
while(k < '0' || k > '9')
flg *= k == '-' ? -1 : 1 ,k = getchar();
while(k >= '0' && k <= '9')
x = x * 10 + k - '0' ,k = getchar();
return x * flg;
}
const int MX = 1e6 + 233;
struct node{
int l ,r ,c;
unsigned int num;
bool zero ,all ,cov;
node *lch ,*rch;
}*root;
void pushup(node *x){
x->zero = x->lch->zero & x->rch->zero;
x->all = x->lch->all & x->rch->all;
}
node *build(int l ,int r){
node *x = new node;
x->l = l;
x->r = r;
x->zero = true;
x->all = false;
x->cov = false;
x->c = 0;
x->num = 0;
if(l == r){
x->lch = nullptr;
x->rch = nullptr;
}
else{
int mid = (l + r) >> 1;
x->lch = build(l ,mid);
x->rch = build(mid + 1 ,r);
pushup(x);
}return x;
}
void docov(node *x ,bool v){
x->cov = true;
x->c = v;
x->zero = !v;
x->all = v;
x->num = v ? UINT_MAX : 0;
}
void pushdown(node *x){
if(x->cov){
x->cov = false;
docov(x->lch ,x->c);
docov(x->rch ,x->c);
}
}
void cov(node *x ,int l ,int r ,bool val){
if(l <= x->l && x->r <= r) return docov(x ,val);
pushdown(x);
if(l <= x->lch->r) cov(x->lch ,l ,r ,val);
if(r > x->lch->r) cov(x->rch ,l ,r ,val);
return pushup(x);
}
void add(node *x ,LL v){
x->num += v;
x->all = x->num == UINT_MAX;
x->zero = x->num == 0;
}
int __add(node *x ,int l ,int r){ // 找到最小的不是全 1 的 pos
if(x->r < l || x->l > r) return 0;
if(x->all) return 0;
if(x->l == x->r){
return add(x ,1) ,x->l;
}
pushdown(x);
int ret = 0;
if(x->lch->all || !(ret = __add(x->lch ,l ,r))){
ret = __add(x->rch ,l ,r);
}
pushup(x);
return ret;
}
void add(node *x ,int p ,LL val){
if(x->l == x->r){
if(x->num + val > UINT_MAX){
x->num = (x->num + val) & UINT_MAX;
add(x ,0);
int pos = __add(root ,p + 1 ,MX);
if(pos - 1 >= p + 1) cov(root ,p + 1 ,pos - 1 ,0);
}
else add(x ,val);
return ;
}
pushdown(x);
if(p <= x->lch->r) add(x->lch ,p ,val);
else add(x->rch ,p ,val);
return pushup(x);
}
void add(LL a ,LL b){
// add a*(2^b)
int bit32 = b / 32 ,bit = b % 32;
LL f = a << bit;
if(f > UINT_MAX){
add((f & UINT_MAX) >> bit ,b);
add(f >> 32 ,(bit32 + 1) * 32);
return ;
}
// debug("%lld %lld\n" ,a ,b);
add(root ,bit32 ,f);
}
int __del(node *x ,int l ,int r){ // 找到最小的不是全 0 的 pos
// debug("Find [%d ,%d] ,allzero = %d\n" ,x->l ,x->r ,x->zero);
if(x->r < l || x->l > r) return 0;
if(x->zero) return 0;
if(x->l == x->r){
return add(x ,-1) ,x->l;
}
pushdown(x);
int ret = 0;
if(x->lch->zero || !(ret = __del(x->lch ,l ,r))){
ret = __del(x->rch ,l ,r);
}
pushup(x);
return ret;
}
void del(node *x ,int p ,LL val){
if(x->l == x->r){
if(x->num - val < 0){
x->num = x->num - val + UINT_MAX + 1;
add(x ,0);
int pos = __del(root ,p + 1 ,MX);
if(pos - 1 >= p + 1) cov(root ,p + 1 ,pos - 1 ,1);
}
else add(x ,-val);
return ;
}
pushdown(x);
if(p <= x->lch->r) del(x->lch ,p ,val);
else del(x->rch ,p ,val);
return pushup(x);
}
void sub(LL a ,LL b){
int bit32 = b / 32 ,bit = b % 32;
LL f = a << bit;
if(f > UINT_MAX){
sub((f & UINT_MAX) >> bit ,b);
sub(f >> 32 ,(bit32 + 1) * 32);
return ;
}
del(root ,bit32 ,f);
}
LL query(node *x ,int p){
if(x->l == x->r) return x->num;
pushdown(x);
if(p <= x->lch->r) return query(x->lch ,p);
return query(x->rch ,p);
}
int query(int pos){
int bit32 = pos / 32 ,bit = pos % 32;
return (query(root ,bit32) >> bit) & 1;
}
void output(node *x){
if(x->l == x->r){
for(int i = 0 ; i < 32 ; ++i){
debug("%u" ,(x->num >> i) & 1);
}
return;
}
pushdown(x);
output(x->lch) ,output(x->rch);
}
int main(){
__FILE([NOI2017]整数);
int n = read(); read() ,read() ,read();
root = build(0 ,MX);
for(LL i = 1 ,op ,a ,b ; i <= n ; ++i){
// debug("%d\n" ,i);
op = read();
if(op == 1){
a = read() ,b = read();
// assert(a >= 0);
if(a > 0) add(a ,b);
else sub(-a ,b);
}
else{
a = read();
printf("%d\n" ,query(a));;
}
// output(root);
// debug("\n");
}
}