codeforces 1444C Team Building (可持久化扩展域并查集)

题目链接:https://codeforces.com/problemset/problem/1444/C

第一想法就是暴力枚举所有两个点的子图然后判断二分图

正难则反,考虑补集转化,即不合法的方案数,有两种情况:

  1. 存在单色环
  2. 存在双色环

所以只需要统计出不含环的单一颜色的数量 \(cnt\) 和不合法的双色环数量 \(k\)
最终答案就是 $ \frac{cnt * (cnt - 1)}{2} - k$

统计答案使用并查集判断是否合法即可,
先将合法的单色边连好,然后依次加另一种颜色的边(边提前排好序保证相同颜色的被连续枚举),
每统计完一种颜色都要撤销掉这些操作(也即回到历史版本)
可持久化并查集即可

同时维护奇偶性需要用到扩展域并查集,将一个点拆成两个,\(x_self\)\(x_another\),
如果有边则将 \(x_self, y_another\)\(x_another, y_self\) 连到一起,表示两点不在同一队伍里,
如果 \(x_self, y_self\) 已经在同一集合里,则不合法

坑点:如果使用带权并查集的话,一定要路径压缩,而路径压缩在可持久化并查集中并不容易实现,所以选择扩展域并查集

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<stack>
#include<queue>
using namespace std;
typedef long long ll;

const int maxn = 500010;

int n, m, k, cnt, tot; ll ans;
int c[maxn], rt[maxn * 5];

struct E{
	int u, v, uc, vc;
	bool operator < (const E &a) const {
		if(uc == a.uc){
			return vc < a.vc;
		}
		return uc < a.uc;
	}
}e[maxn];

int fa[maxn * 2], ran[maxn * 2];
int vis[maxn];

struct Node{
	int lc, rc;
	int fa, ran;
}t[maxn * 50];

int fi(int x){ return fa[x] == x ? x : fi(fa[x]); }

void unite(int x, int y){
	x = fi(x), y = fi(y);
	if(x == y) return;
	if(ran[x] < ran[y]){
		fa[x] = y;
	} else{
		fa[y] = x;
		if(ran[x] == ran[y]) ++ran[x]; 
	}
}

void solve_1(){
	for(int i = 1 ; i <= n + n ; ++i) fa[i] = i, ran[i] = 0;
	for(int i = 1 ; i <= m ; ++i){
		if(e[i].uc == e[i].vc && !vis[e[i].vc]){
			int x_self = fi(e[i].u), x_ano = fi(e[i].u + n);
			int y_self = fi(e[i].v), y_ano = fi(e[i].v + n); 
			if(x_self == y_self){
				--cnt;
				vis[e[i].uc] = 1;
			} else{
				unite(e[i].u, e[i].v + n);
				unite(e[i].v, e[i].u + n);
			}
		}
	}
}

void build(int &i, int l, int r){
	i = ++tot;
	if(l == r){
		t[i].fa = l;
		t[i].ran = 0;
		return;
	}
	int mid = (l + r) >> 1;
	build(t[i].lc, l, mid);
	build(t[i].rc, mid + 1, r);
}

void modify(int &i, int k, int p, int l, int r){
	t[++tot] = t[i];
	i = tot;
	if(l == r){
		t[i].fa = k;
		return;
	}
	int mid = (l + r) >> 1;
	if(p <= mid) modify(t[i].lc, k, p, l, mid);
	else modify(t[i].rc, k, p, mid + 1, r);
}

int query(int i, int p, int l, int r){
	if(l == r) return i; // 返回节点编号 
	int mid = (l + r) >> 1;
	if(p <= mid) return query(t[i].lc, p, l, mid);
	else return query(t[i].rc, p, mid + 1, r);
}

//并查集 

void add(int i, int p, int l, int r){ 
	if(l == r){
		++t[i].ran;
		return;
	}
	int mid = (l + r) >> 1;
	if(p <= mid) add(t[i].lc, p, l, mid);
	else add(t[i].rc, p, mid + 1, r);
}

int find(int v, int x){
	int ff = query(v, x, 1, n + n);
	if(t[ff].fa == x) return ff;
	return find(v, t[ff].fa);
}

void uni(int &v, int x, int y){
	x = find(v, x), y = find(v, y);
	if(t[x].ran < t[y].ran){
		modify(v, t[y].fa, t[x].fa, 1, n + n);
	} else{
		modify(v, t[x].fa, t[y].fa, 1, n + n);
		if(t[x].ran == t[y].ran){
			add(v, t[x].fa, 1, n + n);
		}
	}
}

void solve_2(){
	tot = 0;
	build(rt[0], 1, n + n);
	
	int ver = 0, his;

	for(int i = 1 ; i <= m ; ++i){ // 先把同色的合法并查集连接起来 
		if(e[i].uc == e[i].vc && !vis[e[i].uc]){
			++ver;
			rt[ver] = rt[ver - 1];
			
			uni(rt[ver], e[i].u, e[i].v + n);
			uni(rt[ver], e[i].v, e[i].u + n);
		}
	}

//		for(int i = 1 ; i <= n ; ++i){
//			int x = find(rt[ver], i);
//			printf("%d ", t[x].fa);
//		} printf("\n");	

	his = ver;
	int flag;
	for(int i = 1 ; i <= m ; ++i){
		if(e[i].uc == e[i].vc || vis[e[i].uc] || vis[e[i].vc]) continue;
		++ver;
		if(!(e[i].uc == e[i - 1].uc && e[i].vc == e[i - 1].vc)){
			rt[ver] = rt[his];
			flag = 0;
		} else {
			if(flag) continue;
			rt[ver] = rt[ver - 1];
		}
		
		int p_self = find(rt[ver], e[i].u), p_ano = find(rt[ver], e[i].u + n);
		int q_self = find(rt[ver], e[i].v), q_ano = find(rt[ver], e[i].v + n);
		if(t[p_self].fa == t[q_self].fa){
			--ans;
			flag = 1;
		} else {
			uni(rt[ver], e[i].u, e[i].v + n);
			uni(rt[ver], e[i].v, e[i].u + n);
		}
	}

}

ll read(){ ll s=0,f=1; char ch=getchar(); while(ch<'0' || ch>'9'){ if(ch=='-') f=-1; ch=getchar(); } while(ch>='0' && ch<='9'){ s=s*10+ch-'0'; ch=getchar(); } return s*f; }

int main(){
	n = read(), m = read(), k = read(); cnt = k;
	
	for(int i = 1 ; i <= n ; ++i) c[i] = read();
	
	for(int i = 1 ; i <= m ; ++i){
		e[i].u = read(), e[i].v = read();
		e[i].uc = c[e[i].u], e[i].vc = c[e[i].v];
		if(e[i].uc > e[i].vc){
			swap(e[i].u, e[i].v);
			swap(e[i].uc, e[i].vc);
		}
	}
	
	sort(e + 1, e + 1 + m);
//	for(int i = 1 ; i <= m ; ++i){
//		printf("%d %d %d %d\n", e[i].u, e[i].v, e[i].uc, e[i].vc);
//	}
	
	solve_1();
	
	ans = 1ll * (cnt - 1) * cnt / 2;
//	printf("%d\n", ans);

	solve_2();
	
	printf("%lld\n", ans);
	
	return 0;
}
posted @ 2020-11-18 22:04  Tartarus_li  阅读(205)  评论(1编辑  收藏  举报