P3302 [SDOI2013]森林
题意描述:
给你一个 \(n\) 个点 \(m\) 条边的森林,有 \(T\) 次操作,每次操作分为两种类型:
- 操作1:询问 \(x-y\) 路径上节点权值第 \(k\) 小的是多少。
- 操作2:连接 \(x-y\) 这一条边,保证连完之后还是森林。
数据范围:强制在线, \(n,m,T\leq 8e4\)
solution
先不考虑操作二,只看操作一的话直接树上主席树即可。
对于操作二,首先暴力合并是不可能的,考虑启发式合并,记录一下每棵树的大小,每次合并把小的合并到大的树当中,合并的时候 \(DFS\) 一下节点数量较少的树,在 \(DFS\) 的过程中维护一下倍增数组和主席树即可。
维护倍增数组的话一个比较好的实现方法就是边 \(DFS\) 边修改(个人认为这样还好些点)。
时间复杂度 \(O(nlog^2n)\) ,空间复杂度 \(O(nlogn)\)
然后这道题就完了 QAQ。
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N = 2e5+10;
int n,m,u,v,id,t,k,x,y,cnt,tot,last,num;
int head[N],fa[N][25],rt[N*15],siz[N],dep[N],b[N],w[N];
struct node
{
int to,net;
}e[N<<1];
struct Tree
{
int lc,rc,sum;
}tr[40000010];
inline int read()
{
int s = 0,w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
return s * w;
}
void add(int x,int y)
{
e[++cnt].to = y;
e[cnt].net = head[x];
head[x] = cnt;
}
void up(int o)
{
tr[o].sum = tr[tr[o].lc].sum + tr[tr[o].rc].sum;
}
void insert(int &o,int last,int l,int r,int x,int val)
{
o = ++tot;
if(l == r)
{
tr[o].sum = tr[last].sum + val;
return;
}
tr[o].lc = tr[last].lc;
tr[o].rc = tr[last].rc;
int mid = (l + r)>>1;
if(x <= mid) insert(tr[o].lc,tr[last].lc,l,mid,x,val);
if(x > mid) insert(tr[o].rc,tr[last].rc,mid+1,r,x,val);
up(o);
}
void dfs(int x,int f)//dfs维护倍增数组和主席树
{
dep[x] = dep[f] + 1; fa[x][0] = f; siz[x] = 1;
for(int i = 1; i <= 20; i++) fa[x][i] = fa[fa[x][i-1]][i-1];
insert(rt[x],rt[f],1,num,w[x],1);
for(int i = head[x]; i; i = e[i].net)
{
int to = e[i].to;
if(to == f) continue;
dfs(to,x);
siz[x] += siz[to];
}
}
int lca(int x,int y)
{
if(dep[x] <= dep[y]) swap(x,y);
for(int i = 20; i >= 0; i--)
{
if(dep[fa[x][i]] >= dep[y])
{
x = fa[x][i];
}
}
if(x == y) return x;
for(int i = 20; i >= 0; i--)
{
if(fa[x][i] != fa[y][i])
{
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
int query(int x,int y,int z,int u,int l,int r,int k)
{
if(l == r) return l;
int num = tr[tr[x].lc].sum + tr[tr[y].lc].sum - tr[tr[z].lc].sum - tr[tr[u].lc].sum;
int mid = (l + r)>>1;
if(k <= num) return query(tr[x].lc,tr[y].lc,tr[z].lc,tr[u].lc,l,mid,k);
else return query(tr[x].rc,tr[y].rc,tr[z].rc,tr[u].rc,mid+1,r,k-num);
}
int main()
{
id = read(); n = read(); m = read(); t = read();
for(int i = 1; i <= n; i++) b[i] = w[i] = read();
sort(b+1,b+n+1);
num = unique(b+1,b+n+1)-b-1;
for(int i = 1; i <= n; i++) w[i] = lower_bound(b+1,b+num+1,w[i])-b;
for(int i = 1; i <= m; i++)
{
u = read(); v = read();
add(u,v); add(v,u);
}
for(int i = 1; i <= n; i++) if(!dep[i]) dfs(i,0);
for(int i = 1; i <= t; i++)
{
char opt; cin>>opt;
if(opt == 'Q')
{
x = read() ^ last;
y = read() ^ last;
k = read() ^ last;
int Lca = lca(x,y);
last = b[query(rt[x],rt[y],rt[Lca],rt[fa[Lca][0]],1,num,k)];
printf("%d\n",last);
}
else
{
x = read() ^ last;
y = read() ^ last;
if(siz[x] < siz[y]) swap(x,y);
add(x,y); add(y,x);
siz[x] += siz[y];
dfs(y,x);
}
}
return 0;
}