sol-cf1989F
非常 Educational 的好题。
题目条件非常难看,考虑转化为如下形式:
- 对于每一行每一列建立一个虚点表示该行该列。总点数 \(n+m\)。
- 如果要求 \(a_{i,j}\) 为红色,则 \(i\) 向 \(j+n\) 连一条有向边,否则 \(j+n\) 向 \(i\) 连一条有向边。
- 容易发现,如果图是一个 DAG 的话那么我们显然只需要按照拓扑序依次进行操作即可。
- 于是转化为计数图中每个点数大于一的强连通分量的点数平方和。需要支持每次加一条边。
考虑直接做的复杂度,每一次加一条边都暴力地区跑一遍 Tarjan 计算一下强连通块复杂度 \(O(q(n+m))\) 十分垃圾。
注意到复杂度主要是在我们每加一次边都要遍历全图,这非常劣。事实上可能有的边在后面的计算里是完全没有用的。于是我们考虑整体二分。
对于每一条边,我们需要注意到的事情是:。
- 如果在加这一条边之前两侧的点已经处于同一个强连通分量之中,那么这条边没用。
- 否则它可能在之后的某个时间有用,和其它的边构成一个新的强连通分量。
这个性质非常好,非常的像二分。于是我们整体二分当前的时间点属于 \([l,r]\) 并记录此时可能有用的边。
具体地,我们每次加入时间为 \([l,mid]\) 的边然后暴力跑 Tarjan 去计算强连通分量。首先,根据上面的观察,对于时间在 [l, mid] 并且加了之后 \(u,v\) 联通一条边 \(u \to v\),显然要把它扔到 \([l,mid]\) 里面去计算。否则按照强连通分量缩点后扔到 \([mid+1,r]\) 里面去处理(这里有一个简便的实现是直接把这条边的 \(u,v\) 变成相对应的强连通分量的标号)。
对于所有 \(l=r\) 的时刻。显然的,此时做完一遍 Tarjan 后两边节点属于一个强连通分量的边都是会在时刻 \(l\) 产生新的强连通分量的边。于是我们呢对于每一个询问 \(i\) 记录下此时需要加的边。根据前面的推理,这些边加了之后两侧一定会合并成新的强连通分量。于是我们做完整体二分之后再直接用并查集维护一下每个强连通分量以及大小即可。
// Problem: Simultaneous Coloring
// URL: https://www.luogu.com.cn/problem/CF1989F
// Written by WRuperD
#include<bits/stdc++.h>
using namespace std;
const long long inf = 1e18;
const int mininf = 1e9 + 7;
#define int long long
#define pb emplace_back
inline int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}return x*f;}
inline void write(int x){if(x<0){x=~(x-1);putchar('-');}if(x>9)write(x/10);putchar(x%10+'0');}
#define put() putchar(' ')
#define endl puts("")
const int MAX = 4e5 + 10;
struct node{
int u, v, tim;
};
int dfn[MAX], low[MAX], id[MAX], ins[MAX], stc[MAX], clk, topp;
vector <int> g[MAX];
int scc;
void cov(int u){
g[u].clear();
dfn[u] = low[u] = id[u] = ins[u] = 0;
}
void dfs(int u){
dfn[u] = low[u] = ++clk;
stc[++topp] = u;
ins[u] = 1;
for(int v : g[u]){
if(!dfn[v]){
dfs(v);
low[u] = min(low[u], low[v]);
}else if(ins[v]){
low[u] = min(low[u], dfn[v]);
}
}
if(low[u] == dfn[u]){
++scc;
int v;
do{
v = stc[topp--];
id[v] = scc;
ins[v] = 0;
}while(v != u);
}
}
vector <int> Ans[MAX];
void solve2(int l, int r, vector <node> G){
if(l > r) return ;
int mid = (l + r) >> 1;
topp = clk = 0;
scc = 0;
for(auto U : G){
cov(U.u), cov(U.v);
}
for(auto U : G){
if(U.tim <= mid){
g[U.u].pb(U.v);
}
}
for(auto U : G){
if(!dfn[U.u]) dfs(U.u);
}
if(l == r){
for(auto U : G){
if(id[U.u] == id[U.v]){
Ans[l].pb(U.tim);
}
}
return ;
}
vector <node> G1, G2;
for(auto U : G){
if(id[U.u] == id[U.v]){
if(U.tim <= mid) G1.pb(U);
}else{
G2.pb(node{id[U.u], id[U.v], U.tim});
}
}
solve2(l, mid, G1), solve2(mid + 1, r, G2);
}
int ret;
int fa[MAX], siz[MAX];
int find(int x){
if(fa[x] == x) return x;
return fa[x] = find(fa[x]);
}
void merge(int x, int y){
if(find(x) == find(y)) return ;
if(siz[find(x)] > 1) ret -= siz[find(x)] * siz[find(x)];
if(siz[find(y)] > 1) ret -= siz[find(y)] * siz[find(y)];
siz[find(x)] += siz[find(y)];
siz[find(y)] = 0;
fa[find(y)] = find(x);
ret += siz[find(x)] * siz[find(x)];
}
void solve(){
int n = read(), m = read(), q = read();
vector <node> G;
for(int i = 1; i <= q; i++){
int u = read(), v = read();
char ch = getchar(); while(ch != 'R' and ch != 'B'){ch = getchar();}
if(ch == 'R'){
G.pb(node{u, v + n, i});
}else{
G.pb(node{v + n, u, i});
}
}
solve2(1, q, G);
for(int i = 1; i <= n + m; i++) fa[i] = i, siz[i] = 1;
for(int i = 1; i <= q; i++){
for(auto u : Ans[i]){
merge(G[u - 1].u, G[u - 1].v);
}
write(ret), endl;
}
}
signed main(){
int t = 1;
while(t--) solve();
return 0;
}