H - Set【18南京网络赛】
H - Set
题意:
有 \(n\) 个集合,初始时第 \(i\) 个集合中的数只有 \(a_i\)
支持三种操作
1 u v
若第 \(u\) 个数和第 \(v\) 个数在不同的集合中,则将这两个集合合并
2 u
把第 \(u\) 个数所在的集合中所有的数都+1
3 u k x
询问操作,你需要输出第 \(u\) 个数所在的集合中满足 \(t = x \ \ (mod\ \ 2^k)\) 的 \(t\) 的个数
思路:
其实就是用 Tire
树来存储集合里的数,因为只需要合并,直接合并就好。
对于操作二,只需要交换左子树和右子树,这样左子树就相当于全部加上了 1
,然后递归处理左子树,实在是妙
在插入与合并的时候维护节点的子树大小信息,在查询操作的时候直接输出子树大小就可以
这个题当时写的时候 debug
滴了一下午
void merge(int rt1,int rt2){
// 暴力合并好像就可以
// 把 rt1 合并到 rt2
tr[rt2].size += tr[rt1].size;
if(tr[rt1].s[0]){
if (!tr[rt2].s[0]) tr[rt2].s[0] = tr[rt1].s[0];
else merge(tr[rt1].s[0], tr[rt2].s[0]);
}
if(tr[rt1].s[1]){
if (!tr[rt2].s[1])tr[rt2].s[1] = tr[rt1].s[1];
else merge(tr[rt1].s[1], tr[rt2].s[1]);
}
}
这个是正确的版本,我当时写的艾斯比版本如下
void merge(int rt1,int rt2){
// 暴力合并好像就可以
// 把 rt1 合并到 rt2
tr[rt2].size += tr[rt1].size;
if(tr[rt1].s[0]){
if (!tr[rt2].s[0]) tr[rt2].s[0] = getnode();
merge(tr[rt1].s[0], tr[rt2].s[0]);
}
if(tr[rt1].s[1]){
if (!tr[rt2].s[1])tr[rt2].s[1] = getnode();
merge(tr[rt1].s[1], tr[rt2].s[1]);
}
}
一直 段错误, 在合并的时候加了一个判断大小交换开大数据就过了 3/5
个点,我才意识到 段错误是因为 getnode
调用太多了
我一直奇怪我这样的写法本应更节省内存才对
最后只需要 200Mb
内存就可以
#include<bits/stdc++.h>
using namespace std;
const int N = 6e5 + 10;
struct node{
int s[2], size;
}tr[N*30];
int tot, root[N];
int getnode(){
tot++;
tr[tot].s[0] = tr[tot].s[1] = 0;
tr[tot].size = 0;
return tot;
}
int fa[N];
int find(int a){
return a == fa[a] ? a : fa[a] = find(fa[a]);
}
void insert(int rt,int val){
int cur = rt;
bitset<30>s(val);
tr[cur].size++;
for (int i = 0;i < 30;i++) {
int v = s[i];
if (!tr[cur].s[v])tr[cur].s[v] = getnode();
cur = tr[cur].s[v];
tr[cur].size++;
}
}
void merge(int rt1,int rt2){
// 暴力合并好像就可以
// 把 rt1 合并到 rt2
tr[rt2].size += tr[rt1].size;
if(tr[rt1].s[0]){
if (!tr[rt2].s[0]) tr[rt2].s[0] = tr[rt1].s[0];
else merge(tr[rt1].s[0], tr[rt2].s[0]);
}
if(tr[rt1].s[1]){
if (!tr[rt2].s[1])tr[rt2].s[1] = tr[rt1].s[1];
else merge(tr[rt1].s[1], tr[rt2].s[1]);
}
}
void add(int rt){
if (!rt)return;
swap(tr[rt].s[0], tr[rt].s[1]);
add(tr[rt].s[0]);
}
int query(int rt,int k,int x){
bitset<30>s(x);
int cur = rt;
for(int i = 0;i < k;i++){
if (!tr[cur].s[s[i]])return 0;
cur = tr[cur].s[s[i]];
}
return tr[cur].size;
}
int n, m;
int main(){
scanf("%d%d", &n, &m);
for(int i = 1;i <= n;i++){
fa[i] = i;root[i] = getnode();
int x;scanf("%d", &x);
insert(root[i], x);
}
while(m--){
int op, u, v, k, x;
scanf("%d", &op);
if(op == 1){
scanf("%d%d", &u, &v);
u = find(u);
v = find(v);
if (u == v)continue;
//if (tr[root[u]].size > tr[root[v]].size)swap(u, v);
merge(root[u], root[v]);
fa[u] = v;
}
else if(op == 2){
scanf("%d", &u);
u = find(u);
add(root[u]);
}
else{
scanf("%d%d%d", &u, &k, &x);
u = find(u);
printf("%d\n", query(root[u], k, x));
}
}
}