Codeforces 1140F 线段树 分治 并查集
题意及思路:https://blog.csdn.net/u013534123/article/details/89010251
之前cf有一个和这个相似的题,不过那个题只有合并操作,没有删除操作,直接并查集搞一搞就行了。对于这个题,因为有删除操作,我们对操作序列建一颗线段树,记录每个操作影响的区间操作就可以了。这里的并查集不能路径压缩,要按秩合并,这样复杂度是O(logn)的。
代码:
#include <bits/stdc++.h> #define ls (o << 1) #define rs (o << 1 | 1) #define INF 0x3f3f3f3f #define db double #define pii pair<int, int> #define LL long long using namespace std; const int maxn = 300010; const int Base = 300000; vector<pii> tr[maxn * 4]; map<pii, int> mp; map<pii, int>::iterator it; LL res[maxn], cnt_x[maxn * 2], cnt_y[maxn * 2], sz[maxn * 2]; int f[maxn * 2]; LL ans; pii a[maxn]; int get(int x) { if(x == f[x]) return x; return get(f[x]); } void add(int o, int l, int r, int ql, int qr, pii val) { if(l >= ql &&r <= qr) { tr[o].push_back(val); return; } int mid = (l + r) >> 1; if(ql <= mid) add(ls, l, mid, ql, qr, val); if(qr > mid) add(rs, mid + 1, r, ql, qr, val); } void del(int x, int y) { int x1 = get(x), y1 = get(y); if(x1 != y1) return; ans -= cnt_x[x] * cnt_y[x]; cnt_x[x] -= cnt_x[y], cnt_y[x] -= cnt_y[y]; sz[x] -= sz[y]; ans += cnt_x[x] * cnt_y[x]; ans += cnt_x[y] * cnt_y[y]; f[y] = y; } pii merge(int x, int y) { int x1 = get(x), y1 = get(y); if(x1 == y1) return make_pair(-1, -1); if(sz[x1] < sz[y1]) swap(x1, y1); ans -= cnt_x[x1] * cnt_y[x1]; ans -= cnt_x[y1] * cnt_y[y1]; sz[x1] += sz[y1]; cnt_x[x1] += cnt_x[y1], cnt_y[x1] += cnt_y[y1]; ans += cnt_x[x1] * cnt_y[x1]; f[y1] = x1; return make_pair(x1, y1); } void dfs(int o, int l, int r) { if(l == 12) { l++; l--; } stack<pii> s; for (auto x : tr[o]) { pii tmp = merge(x.first, x.second); if(tmp.first != -1) s.push(tmp); } if(l == r) res[l] = ans; else { int mid = (l + r) >> 1; dfs(ls, l, mid); dfs(rs, mid + 1, r); } while(!s.empty()) { del(s.top().first, s.top().second); s.pop(); } } int main() { int n, x, y; scanf("%d", &n); for (int i = 1; i <= n; i++) { scanf("%d%d", &x, &y); y += Base; a[i] = make_pair(x, y); if(mp.find(a[i]) == mp.end()) mp[a[i]] = i; else { add(1, 1, n, mp[a[i]], i - 1, a[i]); mp.erase(a[i]); } } for (it = mp.begin(); it != mp.end(); it++) { add(1, 1, n, it -> second, n, it -> first); } for (int i = 1; i <= Base; i++) { f[i] = i, cnt_x[i] = 1, cnt_y[i] = 0, sz[i] = 1; } for (int i = Base + 1; i <= Base * 2; i++) { f[i] = i, cnt_x[i] = 0, cnt_y[i] = 1, sz[i] = 1; } dfs(1, 1, n); for (int i = 1; i <= n; i++) printf("%lld ", res[i]); }