洛谷P3402 【模板】可持久化并查集
题目描述
n个集合 m个操作
操作:
1 a b 合并a,b所在集合
2 k 回到第 k 次操作之后的状态（查询也算作一次操作；k = 0 表示初始状态）
3 a b 询问a,b是否属于同一集合,是则输出1否则输出0
输入输出格式
输入格式:
第一行两个整数 n, m。接下来 m 行，每行一个操作，格式如上所述。
输出格式:
对每个操作 3 输出一行：若 a, b 属于同一集合输出 1，否则输出 0。
输入输出样例
输入样例#1:
5 6
1 1 2
3 1 2
2 0
3 1 2
2 1
3 1 2
输出样例#1:
1
0
1
说明
\(1 \le n \le 10^5, 1 \le m \le 2 \times 10^5\)
By zky 出题人大神犇
题解
用可持久化线段树维护每个版本的 fa 数组以及 rank（或 size）。
按秩合并（或按 size 启发式合并），不做路径压缩：路径压缩一次会修改很多位置，而按秩合并保证树高为 O(log n)，因此每次 find 为 O(log^2 n)。
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
// Plain-int helpers (global, not the std:: templates).
inline int max(int a, int b)
{
    if(b < a) return a;
    return b;
}
inline int min(int a, int b)
{
    if(b > a) return a;
    return b;
}
inline void swap(int &x, int &y)
{
    int held = y;
    y = x;
    x = held;
}
/*
 * Reads one decimal integer from stdin into x.
 * Any non-digit characters before the number are skipped; if the character
 * immediately before the first digit is '-', the value is negated.
 * If end of input is reached before any digit, x is left as 0 and the
 * function returns instead of looping forever.
 */
inline void read(int &x)
{
    x = 0;
    // Must be int, not char: a char cannot hold EOF distinctly from a byte,
    // and with signed or unsigned char the skip loop below never ended at EOF.
    int ch = getchar();
    int prev = ch;
    while(ch < '0' || ch > '9')
    {
        if(ch == EOF) return;
        prev = ch;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if(prev == '-') x = -x;
}
// Conventional "infinity" sentinel; not referenced in the visible code.
const int INF = 0x3F3F3F3F;
// Bound for per-version arrays such as fa[]; must exceed m + 1 (m <= 2e5).
const int MAXN = 400010;
// One node of the persistent segment tree over positions 1..n.
// Internal nodes use ls/rs (indices into node[]); leaves use f and rank.
// tot is pre-incremented before every allocation, so index 0 is never handed out.
// NOTE(review): MAXN * 40 is a capacity guess; each merge allocates about
// one or two root-to-leaf paths (~log2(n) nodes each), plus ~2n for build.
struct Node
{
    int ls;   // left child index
    int rs;   // right child index
    int f;    // leaf only: union-find parent of this position
    int rank; // leaf only: rank used by union by rank
} node[MAXN * 40];
int tot;                 // number of segment-tree nodes allocated so far
int n, m;                // element count and operation count
int fa[MAXN];            // fa[i]: root index of the tree for version i
int now, pos[MAXN], cnt; // not referenced in the visible code
/*
 * Allocates a fresh tree covering [l, r] and stores its root index in o.
 * Each leaf for position l gets f = l (every element is its own set);
 * rank is left at the zero value of the global node[] array.
 */
void build(int &o, int l = 1, int r = n)
{
    o = ++ tot;
    if(l != r)
    {
        const int mid = (l + r) >> 1;
        build(node[o].ls, l, mid);
        build(node[o].rs, mid + 1, r);
        return;
    }
    node[o].f = l;
}
/*
 * Path-copying point update.
 * o becomes the root of a new version identical to the tree rooted at oo,
 * except that the leaf for position p has f = k. Only the nodes on the path
 * to p are newly allocated; every other subtree is shared with oo. The leaf's
 * rank is copied unchanged from oo.
 */
void insert(int &o, int oo, int p, int k, int l = 1, int r = n)
{
    o = ++ tot;
    if(l == r)
    {
        node[o].rank = node[oo].rank;
        node[o].f = k;
        return;
    }
    const int mid = (l + r) >> 1;
    if(p > mid)
    {
        node[o].ls = node[oo].ls; // untouched half is shared
        insert(node[o].rs, node[oo].rs, p, k, mid + 1, r);
    }
    else
    {
        node[o].rs = node[oo].rs; // untouched half is shared
        insert(node[o].ls, node[oo].ls, p, k, l, mid);
    }
}
/*
 * Increments, in place, the rank stored at position p of the tree rooted at o.
 * No nodes are copied: every version sharing that leaf sees the change.
 * NOTE(review): only call this on a leaf freshly created for the current version.
 */
void add(int o, int p, int l = 1, int r = n)
{
    while(l != r)
    {
        const int mid = (l + r) >> 1;
        if(p <= mid)
        {
            o = node[o].ls;
            r = mid;
        }
        else
        {
            o = node[o].rs;
            l = mid + 1;
        }
    }
    ++ node[o].rank;
}
/* Returns the parent pointer f stored at position p in the tree rooted at o. */
int ask_f(int o, int p, int l = 1, int r = n)
{
    while(l != r)
    {
        const int mid = (l + r) >> 1;
        if(p > mid)
        {
            o = node[o].rs;
            l = mid + 1;
        }
        else
        {
            o = node[o].ls;
            r = mid;
        }
    }
    return node[o].f;
}
/* Returns the rank stored at position p in the tree rooted at o. */
int ask_rank(int o, int p, int l = 1, int r = n)
{
    while(l != r)
    {
        const int mid = (l + r) >> 1;
        if(p > mid)
        {
            o = node[o].rs;
            l = mid + 1;
        }
        else
        {
            o = node[o].ls;
            r = mid;
        }
    }
    return node[o].rank;
}
/*
 * Returns the representative of x in version k.
 * It follows parent pointers until it reaches an element that is its own parent.
 * There is no path compression: version trees are shared and are only read here.
 */
int find(int k, int x)
{
    for(int parent = ask_f(fa[k], x); parent != x; parent = ask_f(fa[k], x))
        x = parent;
    return x;
}
/*
 * Builds version k by uniting the sets of x and y as they are in version k - 1.
 * Union by rank applies: the root with the smaller rank is attached under the
 * other root. On a tie, y's root goes under x's root, and x's root's rank
 * grows by one. If x and y are already in the same set, version k equals
 * version k - 1.
 */
void merge(int k, int x, int y)
{
    // Start from the previous version so fa[k] is valid even on the early
    // return below (main also does this; merge no longer depends on it).
    fa[k] = fa[k - 1];
    x = find(k - 1, x), y = find(k - 1, y);
    if(x == y) return;
    int rank_x = ask_rank(fa[k - 1], x), rank_y = ask_rank(fa[k - 1], y);
    if(rank_x < rank_y)
    {
        insert(fa[k], fa[k - 1], x, y);
        return;
    }
    insert(fa[k], fa[k - 1], y, x);
    if(rank_x == rank_y)
    {
        // After the insert above, x's leaf in fa[k] is still shared with
        // fa[k - 1] (and possibly other versions). Bumping it in place would
        // rewrite history and break the rank bound for those versions.
        // Path-copy x's leaf first (parent stays x, rank is copied), then
        // increment only the fresh copy.
        insert(fa[k], fa[k], x, x);
        add(fa[k], x);
    }
}
/*
 * Reads n and m and builds version 0, where every element is its own set.
 * It then processes m operations; version op is the state after operation op:
 *   1 a b : union a and b          (version op derives from op - 1)
 *   2 k   : version op is a copy of version k
 *   3 a b : print 1 if a and b are in the same set, otherwise 0
 * Any other opcode is handled like 3, as in the original if/else chain.
 */
int main()
{
    read(n), read(m);
    build(fa[0]);
    for(int op = 1; op <= m; ++ op)
    {
        int type, a, b;
        read(type);
        switch(type)
        {
        case 1:
            read(a), read(b);
            fa[op] = fa[op - 1];
            merge(op, a, b);
            break;
        case 2:
            read(a);
            fa[op] = fa[a];
            break;
        default:
            read(a), read(b);
            fa[op] = fa[op - 1];
            printf("%d\n", find(op, a) == find(op, b) ? 1 : 0);
            break;
        }
    }
    return 0;
}