LOJ3462. 「WC2021」括号路径
一张图,给出若干个三元组\((u,v,w)\),表示有\(u\to v\)的颜色为\(w\)的左括号和\(v\to u\)的颜色为\(w\)的右括号。询问多少对\((u,v),u<v\)使得\(u\)到\(v\)存在路径使得这个路径上的括号序合法。
\(n\le 3*10^5,m\le 6*10^5\)
在比赛最后一个小时前一直在想怎么前缀和不小于零……于是连暴力都不会写……
后来忽然发现可以记\(f_{u,v}\)表示\(u\)可以到\(v\),然后暴力扩展。估分32,实际水了48。
比赛之后才发现这个东西是双向的,也就是说\(f_{u,v}=f_{v,u}\)。
于是如果\(f_{u,v}=1\),可以把\(u,v\)看成等价类。
用并查集+启发式合并维护搞出所有等价类即可。
具体:把等价类缩点,并且记录这些点的每种颜色的入边。把点丢到一个队列里,操作某个点的时候,找到所有入边大于等于\(2\)的边,将这些边连向的点缩起来,然后丢到队列中。
注意细节。
using namespace std;
#include <bits/stdc++.h>
#define N 300005
#define M 600005
#define ll long long
int n,m,k;
int dsu[N];
int getdsu(int x){return dsu[x]==x?x:dsu[x]=getdsu(dsu[x]);}
map<int,vector<int> > e[N];
set<int> b[N];
int sz[N];
queue<int> q;
bool inq[N];
vector<int> o[N];
bool cmpo(int x,int y){return sz[x]>sz[y];}
void merge(int x,int y){
for (auto i=e[y].begin();i!=e[y].end();++i){
auto t=&e[x][i->first];
for (auto j=i->second.begin();j!=i->second.end();++j)
t->push_back(*j);
if (t->size()>=2)
b[x].insert(i->first);
}
e[y].clear();
dsu[y]=x;
sz[x]+=sz[y];
}
void BFS(){
for (int i=1;i<=n;++i)
q.push(i),inq[i]=1;
while (!q.empty()){
int x=q.front();
q.pop();
inq[x]=0;
x=getdsu(x);
int cnt=0;
for (auto i=b[x].begin();i!=b[x].end();++i){
auto t=&e[x][*i];
for (auto j=t->begin();j!=t->end();++j)
o[cnt].push_back(*j);
cnt++;
for (int j=t->size();j>=2;--j)
t->pop_back();
}
b[x].clear();
static int bz[N],BZ;
for (int i=0;i<cnt;++i){
++BZ;
int mx=0;
for (int j=0;j<o[i].size();++j){
int y=getdsu(o[i][j]);
if (sz[y]>sz[mx]) mx=y;
}
for (int j=0;j<o[i].size();++j){
int y=getdsu(o[i][j]);
if (y==mx) continue;
merge(mx,y);
}
if (!inq[mx])
inq[mx]=1,q.push(mx);
o[i].clear();
}
}
}
int main(){
freopen("in.txt","r",stdin);
// freopen("out.txt","w",stdout);
scanf("%d%d%d",&n,&m,&k);
sz[0]=-1;
for (int i=1;i<=n;++i)
dsu[i]=i;
for (int i=1;i<=m;++i){
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
e[v][w].push_back(u);
if (e[v][w].size()==2)
b[v].insert(w);
sz[v]++;
}
BFS();
static int cnt[N];
ll ans=0;
for (int i=1;i<=n;++i)
cnt[getdsu(i)]++;
for (int i=1;i<=n;++i)
if (dsu[i]==i)
ans+=(ll)cnt[i]*(cnt[i]-1)>>1;
printf("%lld\n",ans);
return 0;
}