LOJ #6144. 「2017 山东三轮集训 Day6」C 重建可持久化字典树
这个题目的Xor,And,Or都是对于所有的数字进行操作的,那么我们会发现对一个二进制位置具有破坏性的操作是只有And和Or的。那么我们定义什么操作叫做破坏性操作。
我们对于这种题肯定是能想到要用字典树来维护的,字典树的每个结点的左儿子是0,右儿子是1,那么当一个二进制位置具有两个儿子的时候,对于每个数字二进制上第x位而言。第x位And一个0会使第x位的右儿子合并到左儿子,如果第x位Or一个1的话会使左儿子合并到右儿子上(虽然不是真的要去合并,下文会提及),对于这种能使子树合并的操作这里称之为破坏性操作。而对于Xor呢则是可以认为是交换两个子树,并不是破坏性操作(这里也不是真的要去交换子树,下文会解释)。
所以这里我们考虑维护每个二进制位是否被破坏过,如果被破坏过我们就重新建一次树。值得一提的是已经被破坏的位置再次Or和And是不会导致子树合并,那么就不能叫破坏,这么一来我们最多重建31次树。被破坏过已经代表这个结点子树被合并过了,所以二进制上被破坏过的位置再次被And和Or是不会具有破坏性的,只会导致子树反转。
那么为什么叫子树反转呢,如果对于每个数字的二进制上第x位,破坏的时候是第x位And上了0,那么这个时候第x位的右子树被合并到了左子树。在下次第x位遇上了Or上了1的操作时,我们可以认为是第x位的左子树变成了右子树(这里也不是真的要去交换子树,下文会解释)。
那么如何维护是否二进制上每一位是否被破坏过呢,我们使用三个变量,sam, whl, rev。
sam:初始为0,维护二进制上每一位是否被破坏,对于二进制上的每一位而言,如果某一位为0,则是没有被破坏,如果某一位为1则是被破坏了。
whl:维护二进制上每一位被破坏的时候是被0破坏了还是被1破坏了,哪些位上有意义还得看一下sam对应位置是否为1。对于whl我们初始化为(1 << 31) - 1,也就是二进制下前31位全为1的数字。对于这个whl的二进制每一位的表示只有和sam在相同位置上为1的时候才有意义。假设sam为(1110),whl为(1011),大家都知道二进制最右边为最低位也就是第一位(还是防止一下有些萌新不知道),那么对于whl而言他的第一位没有意义,第二三四位才有意义。并且第二位和第四位被破坏的时候是被1破坏的,第三位被破坏的时候是被0破坏的。至于为什么sam二进制下为0的的时候whl对应位置没有意义也很显然。没有被破坏的位置,怎么能说它是被0破坏了还是被1破坏了...
rev:维护二进制上每一位是否被异或(翻转),如果为1则是需要翻转,为0则是不需要翻转。
那么如何进行维护呢?
每次进行And和Or的时候,提前用一个变量L维护,判断sam在被And和Or破坏之后是否和L相同,如果不同则重新建树,我们选择直接把每个数Or上sam进行重新建树实现,这么一来对于这些被破坏了的点,左子树会被合并到右子树上。对于And和Or的破坏操作维护起来也不同,我们设置一个all=(1<<31)-1。然后Or x的时候可以直接Or 上x然后和L进行判断,但是And x的时候需要把And转换为Or进行破坏,具体实现是用sam |= all ^ x(不要直接对x取反会弄出负数来),这样就可以交换x里面的0和1,转换成Or操作进行破坏。对于rev而言,每次进行Xor x,直接对rev ^= x即可,对于And和Or x操作的时候,我们可以认为x把rev破坏了,把被破坏的位置置为0即可。
那么如何对每个区间[y, x]进行查询k小呢,我们很容易想到可持久化字典树,可以维护以每个位置为前缀的所有信息的数据结构,每个位置的数字视为一个版本进行维护。每次插入的时候维护一下每个位置的左右子树被走过多少次。如果当前走到的位置没有被破坏,我们就看当前二进制位假设是第a位是否需要进行翻转,直接看(rev >> a) & 1是否为1即可,为1则是需要翻转左右子树。然后看两个位置[y, x]左子树(0)的走过数量的差,小于k就走右子树,反之走右子树(1)并且将答案ans += (1 << a)。如果被破坏那就只用走右子树,因为我们重建树是直接合并到右子树的。但是统计答案的时候需要看当前位是被0破坏还是被1破坏,然后再看是否需要反转。如果计算完这两步右子树还是应该还是在右子树,那么答案ans += 1 << a;
空间复杂度分析
根据题目数据范围n为5e4,然后每个数字维护的时候算上root需要32的空间,所以我们开5e4 * 35就够了。
时间复杂度分析
我们最多重建31次树,然后每次建树的时间复杂度差不多是5e4 * 32。所以差不多是5e8的时间复杂度,差不多是2s。然后对于每次查询是for 30-> 1的,所以差不多能在2s左右过掉这个题。
最后想说一下这个题用这个写法rev和whl还有sam的维护的东西可以变,以前写这道题看着题解写的,这次找不到之前那个自己看过的题解了,也不知道三个变量干嘛的,被迫自己想了很久,结果没想到三个变量维护的东西都变了。
所以还是鼓励大家自己多思考一下其他解法,只是照着题解写当时只觉得啊这么写就是好对,不自己独立思考过段时间就忘掉了,这也是我写博客园的初衷,只有自己彻底搞明白了才能明白写的明白题解,对于复习以前的题而言还是很有帮助的。同时也鼓励学弟写写题解
#include <iostream>
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <stack>
#include <queue>
#include <numeric>
#include <cassert>
#include <bitset>
#include <cstdio>
#include <vector>
#include <unordered_set>
#include <cmath>
#include <map>
#include <unordered_map>
#include <set>
#include <deque>
#include <tuple>
#include <array>
#define all(a) a.begin(), a.end()
#define cnt0(x) __builtin_ctz(x)
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define cntone(x) __builtin_popcount(x)
#define db double
#define fs first
#define se second
#define AC main(void)
#define HYS std::ios::sync_with_stdio(false);std::cin.tie(0);std::cout.tie(0);
typedef std::pair<int, int > PII;
typedef std::pair<int, std::pair<int, int>> PIII;
typedef std::pair<ll, ll> Pll;
typedef std::pair<double, double> PDD;
using ld = double long;
const long double eps = 1e-9;
const int INF = 0x3f3f3f3f;
const int N = 2e5 + 10, M = 5e4 + 10;
int n, m, p;
int a[N];
int d1[] = {0, 0, 1, -1};
int d2[] = {1, -1, 0, 0};
int sam, whl, rev;
struct HisTire {
int ch[M * 35][2], ver[M * 35], root[M * 35], idx, sz[M * 35];
int sigma;
inline void Init() {
ver[0] = -1;//边界
idx = 0;
sigma = 30;//位数
ins(root[0], 0, 0, 0);//边界
}
inline void ins(int &xx, int o, int v, int pos) {
sz[xx = ++ idx] = sz[o] + 1;
ver[idx] = pos;
for (int d = sigma, c, x = xx; d >= 0; d --) {
c = (v >> d) & 1;
ch[x][c ^ 1] = ch[o][c ^ 1];
sz[x = ch[x][c] = ++ idx] = sz[o = ch[o][c]] + 1;
ver[idx] = pos;
}
}
//求出L到后面所有版本的异或最大值
inline int query(int x, int L, int cur) { //当前版本root[i] 左边界l - 1 查询值
int res = 0;
for (int i = sigma; i >= 0; i --) {
int c = cur >> i & 1;
if (ver[ch[x][!c]] >= L) {
x = ch[x][!c];
res += 1 << i;
} else
x = ch[x][c];
}
return res;
}
//查询区间k小
inline int queryk(int x, int y, int k) { //右端点root[r] 左端点root[l - 1] 第k小
int res = 0;
for (int i = sigma; ~i; i --) {
int lsx = ch[x][0], rsx = ch[x][1], lsy = ch[y][0], rsy = ch[y][1];
int p = (whl >> i) & 1;
p ^= (rev >> i) & 1;//p为1表示不用翻转,p为0表示需要翻转
if ((sam >> i) & 1) {
x = rsx, y = rsy;
if (p)
res += (1 << i);
continue;
}
if ((rev >> i) & 1) {
std::swap(lsx, rsx);
std::swap(lsy, rsy);
}
int ss = sz[lsx] - sz[lsy];
if (k <= ss)
x = lsx, y = lsy;
else {
k -= ss;
x = rsx, y = rsy;
res += (1 << i);
}
}
return res;
}
} HT;
inline void build() {
HT.Init();
for (int i = 1; i <= n; i ++) {
a[i] |= sam;
HT.ins(HT.root[i], HT.root[i - 1], a[i], i);
}
}
int all = (1 << 31) - 1;
inline void solve() {
std::cin >> n >> m;
whl = all;
for (int i = 1; i <= n; i ++)
std::cin >> a[i];
build();
int l, r, x;
auto &root = HT.root;
while (m --) {
std::string str;
std::cin >> str;
if (str[1] == 'o') { //xor
std::cin >> x;
rev ^= x;
} else if (str[1] == 'n') { //and
std::cin >> x;
l = sam, whl &= x, sam |= (all ^ whl), rev &= x;
if (sam != l)
build();
} else if (str[1] == 's') { //ask
std::cin >> l >> r >> x;
std::cout << HT.queryk(root[r], root[l - 1], x) << '\n';
} else if (str[1] == 'r') { //or
std::cin >> x;
l = sam, sam |= x, rev &= (all ^ x), whl |= x;
if (l != sam)
build();
}
}
}
signed AC{
HYS
int _ = 1;
//std::cin >> _;
while (_ --)
solve();
return 0;
}
最后再贴一下以前看题解写的代码,现在已经看不懂在干嘛了~~
#include <iostream>
#include <cstring>
#include <iomanip>
#include <algorithm>
#include <stack>
#include <queue>
#include <numeric>
#include <cassert>
#include <bitset>
#include <cstdio>
#include <vector>
#include <unordered_set>
#include <cmath>
#include <map>
#include <unordered_map>
#include <set>
#include <deque>
#define all(a) a.begin(), a.end()
#define cnt0(x) __builtin_ctz(x)
#define endl '\n'
#define itn int
#define ll long long
#define ull unsigned long long
#define rep(i, a, b) for(int i = a;i <= b; i ++)
#define per(i, a, b) for(int i = a;i >= b; i --)
#define cntone(x) __builtin_popcount(x)
#define db double
#define fs first
#define se second
#define AC main(void)
#define re register
#define HYS std::ios::sync_with_stdio(false);std::cin.tie(0);std::cout.tie(0);
typedef std::pair<int, int > PII;
typedef std::pair<int, std::pair<int, int>> PIII;
typedef std::pair<ll, ll> Pll;
typedef std::pair<double, double> PDD;
using ld = double long;
const long double eps = 1e-9;
int d1[] = {0, 0, 1, -1};
int d2[] = {1, -1, 0, 0};
const int N = 1e5 + 10, M = 5e4 + 10;
const int INF = 0x3f3f3f3f;
int n, m;
int _ = 1;
int sam, rev, whl;
int a[M];
struct HisTire {
int ch[M * 25 * 30][2], ver[M * 25 * 30], root[M * 30], idx,
sz[M * 25 * 30];//25是固定的空间大小, 20是二进制最多多少位数
int sigma;
inline void Init() {
ver[0] = -1;//边界
idx = 0;
sigma = 30;//位数
insert(root[0], 0, 0, 0);//边界
}
//不带维护大小的插入 维护版本位置的插入
inline void insert(int &u, int y, int pos,
int cur) { //当前版本的根节点编号root[i] 上个版本的根节点编号root[i - 1] 当前是第几个版本
u = ++ idx;
int x = idx;
ver[x] = pos;
for (int i = sigma; i >= 0; i --) {
int c = cur >> i & 1;
ch[x][!c] = ch[y][!c];
ch[x][c] = ++ idx;
x = ch[x][c], y = ch[y][c];
ver[x] = pos;
}
}
//带维护大小的插入 不维护版本位置的插入
inline void ins(int &xx, int o, int
k) {//当前版本的根root[i] 上一个版本的根root[i - 1] 插入的值
sz[xx = ++ idx] = sz[o] + 1;
for (register int d = 30, c, x = xx; d >= 0; --d)
c = (k >> d) & 1, ch[x][c ^ 1] = ch[o][c ^ 1], sz[x = ch[x][c] = ++ idx] = sz[o = ch[o][c]] + 1;
}
//求出L到后面所有版本的异或最大值
inline int query(int x, int L, int cur) { //当前版本root[i] 左边界l - 1 查询值
int res = 0;
for (register int i = sigma; i >= 0; i --) {
int c = cur >> i & 1;
if (ver[ch[x][!c]] >= L) {
x = ch[x][!c];
res += 1 << i;
} else
x = ch[x][c];
}
return res;
}
//查询区间k小
inline int queryk(int x, int y, int k) { //右端点root[r] 左端点root[l - 1] 第k小
int res = 0;
for (int i = sigma; ~i; i --) {
if (sam >> i & 1) {
res += (whl & (1 << i)), x = ch[x][1], y = ch[y][1];
continue;
}
int lsx = ch[x][0], rsx = ch[x][1], lsy = ch[y][0], rsy = ch[y][1];
if ((rev >> i) & 1)
std::swap(lsx, rsx), std::swap(lsy, rsy);
int ss = sz[lsx] - sz[lsy];
if (k <= ss)
x = lsx, y = lsy;
else {
k -= ss;
x = rsx, y = rsy;
res += (1 << i);
}
}
return res;
}
} HT;
void re_build() {
for (int i = 1; i <= n; i ++)
a[i] = (a[i] ^ rev) | sam;
rev = 0;
HT.Init();
for (int i = 1; i <= n; i ++)
HT.ins(HT.root[i], HT.root[i - 1], a[i]);
}
void solve() {
HT.Init();
std::cin >> n >> m;
int all = (1 << 31) - 1;
for (int i = 1; i <= n; i ++) {
std::cin >> a[i];
HT.ins(HT.root[i], HT.root[i - 1], a[i]);
}
while (m --) {
std::string str;
int x, l, r, k;
std::cin >> str;
if (str[1] == 'n') {
std::cin >> x;
l = sam, whl &= x, sam |= (all ^ x);
if (l != sam)
re_build();
} else if (str[1] == 'r') {
std::cin >> x;
l = sam, sam |= x, whl |= x;
if (l != sam)
re_build();
} else if (str[1] == 'o') {
std::cin >> x;
rev ^= x, whl ^= x;
} else {
std::cin >> l >> r >> k;
std::cout << HT.queryk(HT.root[r], HT.root[l - 1], k) << '\n';
}
}
}
int AC{
HYS
//std::cin >> _;
while (_ --)
solve();
return 0;
}