[PA 2020] Trzy drogi
从 2023.4 鸽到 2024.3,终于过了这个题。
pjudge 题解虽然写了,但可能是 bot 写的,写的很不清楚。
根据经典做法,搜出一棵 dfs 树,对非树边赋随机权值,树边权值为跨过它的所有非树边的权值 xor。
那割三条边能割开的条件就是:选三条边的一个子集,这个子集中的边权 xor 为 0。
也就是存在 \(w_i = 0\) 或 \(w_i \oplus w_j = 0\) 或 \(w_i \oplus w_j \oplus w_k = 0\)。
考虑对图进行一些预处理,判掉一些简单情况。
首先判掉割边:对于一条割边,先计算割了它的答案,再把两个端点用并查集缩起来,表示以后再也不会割这条边。
处理完了图上不会有 \(w_i = 0\) 的边(也就是割边)。
接下来有一步炫酷操作。
先计算割两条边(再加随意一条边)就能把图割开的情况,也就是 \(w_i \oplus w_j = 0\) 的情况。只需要找到所有权值相等的边进行计算。
接下来的限制就是只有砍三条边,也就是 \(w_i \oplus w_j \oplus w_k = 0\) 的情况。
由于没有 \(w_i = 0\),那么 \(w_i,w_j,w_k\) 都是不同的,也就是同一权值的边只会割一条。
观察到相同权值的边割哪条效果都是一样的,可以对图进行如下处理:
给每条边加上“长度” \(l_i\),表示割这条边答案乘上几(也就是说砍三条边的话答案加上 \(l_i\times l_j\times l_k\))。把同一颜色的边中,选一条将长度设为边的数量,其他的长度设为 0。容易发现不改变答案。
对于长度为 0 的边,可以直接并查集把两个端点缩起来。
这样缩完以后,图上没有权值相同的边,并且因此图上没有度数 \(\le 2\) 的点,这对后面操作的复杂度证明有用。
接下来分类讨论砍三条边的情况。
零、三条非树边
由于树边一定使图联通,所以不可能。
一、一条树边,两条非树边
枚举这条树边,看看是否恰有两条非树边跨过这条树边。这部分容易统计。
二、两条树边,一条非树边
这种情况就是选树边 \(a,b\) 和非树边 \(c\),满足 \(c\) 覆盖了 \(a\) 但没有覆盖 \(b\),而且覆盖 \(a\) 和 \(b\) 的其他边集合相同。
不难发现,选的 \(a,b\) 需要是祖先-后代关系。
然后分两种情况讨论:\(a\) 比 \(b\) 更深或 \(a\) 比 \(b\) 更浅。下面称一条非树边更浅的一端为“上端点”,更深的一端为“下端点”。
\(a\) 比 \(b\) 更深的情况:对于覆盖了 \(a\) 的所有非树边,只能选择删“上端点最深”的边,不然这条边一定会使 \(a,b\) 的覆盖集合不相同。
\(a\) 比 \(b\) 更浅的情况:这时要考察覆盖了 \(a\) 的所有非树边的下端点。
下面进行一些分类讨论(可以自行手玩):
- 如果跨过 \(a\) 的边只有两条,那就可以任意删一条;
- 如果 \(\le 3\) 条:如果下端点全部成祖先-后代关系,那只能删“下端点最浅”的边;
- 否则就会有分叉,发现在只有两个分叉的时候可能可行,而且删完 \(c\) 后要满足剩下的下端点全部成祖先-后代关系。这就需要两组中其中一组只有一条边,并且删去那条边。
将所有情况归纳起来,可以发现一个优美的结论:只需要考虑覆盖了 \(a\) 的所有非树边的下端点中,dfs 序最小和最大的两条。
那么对于一个 \(a\),总共只有 \(3\) 个候选的 \(c\)。注意两种情况下,\(3\) 个候选的 \(c\) 可能有重复,需要去重。
把所有树边插进 map,作为 \(b\) 的候选,在 map 中查询权值等于 \(w_a\oplus w_c\) 的边即可。
两者都能用合并子树的可并堆来维护,每次合并子树的可并堆,然后弹掉上端点在子树内的边。一轮的时间复杂度为 \(O(m\log m)\)。
三、三条树边
此时的方案中,非树边一定不会被割。
所以可以将非树边两端的端点用并查集并起来,得到一张新图,在新图上继续做。
由于图中每个点的度数至少为 \(3\),因此 \(|E| \geq \frac{3}{2} |V|\),故被缩掉的边数至少为 \(|E| - |V| + 1 \ge \frac{1}{2}|V|\)。
于是缩的轮数只有 \(\log m\) 轮。在每轮中都统计前两种情况即可。
时间复杂度 \(O(m\log^2 m)\),代码实现十分繁琐。
#define fi first
#define se second
#define pb push_back
#define mkp make_pair
typedef pair<int,int>pii;
typedef vector<int>vi;
#define maxn 500005
#define inf 0x3f3f3f3f
#define umap unordered_map<ull,int>
int n,m,msum;
ll res;
ll C2(int x){
return 1ll*x*(x-1)/2;
}
ll C3(int x){
return 1ll*x*(x-1)*(x-2)/6;
}
struct node{
int u,v,c;
}g[maxn];
int dsu[maxn];
int gf(int x){
while(x!=dsu[x])x=dsu[x]=dsu[dsu[x]];
return x;
}
// connect the edge.
bool cut[maxn];
int id[maxn];
void build(){
int nn=0,mm=0;
For(i,1,n)dsu[i]=i,id[i]=0;
For(i,1,m) if(cut[i]) dsu[gf(g[i].u)]=gf(g[i].v);
For(i,1,n) if(i==dsu[i] && !id[i]) id[i]=++nn;
msum=0;
For(i,1,m){
g[i].u=gf(g[i].u),g[i].v=gf(g[i].v);
g[i].u=id[g[i].u],g[i].v=id[g[i].v];
if(g[i].u!=g[i].v) g[++mm]=g[i],msum+=g[i].c;
}
n=nn,m=mm;
For(i,1,m) cut[i]=0;
}
mt19937_64 rnd(64);
int dep[maxn],dfn[maxn],idx;
int fa[maxn],fe[maxn];
int on[maxn];
ull h[maxn],s[maxn];
vector<pii>e[maxn];
void dfs(int u){
// cout<<"dfs "<<u<<"\n";
dfn[u]=++idx;
for(auto [v,i]:e[u]){
if(i==fe[u])continue;
if(!dfn[v]){
fa[v]=u,fe[v]=i,dep[v]=dep[u]+1;
on[i]=1;
// cout<<"on "<<i<<"\n";
dfs(v);
s[u]^=s[v];
}
else if(dep[u]>dep[v]) h[i]=rnd(),s[u]^=h[i],s[v]^=h[i];
}
}
void init()
{
idx=0;
For(i,1,n)e[i].clear(),dfn[i]=dep[i]=0,s[i]=0,fa[i]=fe[i]=0; idx=0;
For(i,1,m)on[i]=0,h[i]=0;
For(i,1,m){
int u=g[i].u,v=g[i].v;
e[u].pb(mkp(v,i)),e[v].pb(mkp(u,i));
}
dep[1]=1,dfs(1);
For(i,1,n) e[i].clear();
For(i,1,m) if(on[i]) {
int u=g[i].u,v=g[i].v;
if(dep[u]<dep[v])swap(u,v);
e[v].pb(mkp(u,i));
h[i]=s[u];
}
}
// for cut edges
void solve0()
{
init();
// For(i,1,m) cout<<h[i]<<" " ;cout<<" h\n";
umap mp;
For(i,1,m) {
if(!h[i]) --msum,res+=C2(msum),cut[i]=1,g[i].c=0;
else mp[h[i]]++;
}
For(i,1,m){
if(!h[i]) continue;
if(!mp.count(h[i])) cut[i]=1,g[i].c=0; // don't use the edge
else {
g[i].c=mp[h[i]];
// cout<<"h: "<<h[i]<<" "<<g[i].c<<"\n";
res+=C2(g[i].c)*(msum-g[i].c);
res+=C3(g[i].c);
mp.erase(h[i]);
}
}
build();
}
umap mp;
ll sum[maxn][3];
void dfs1(int u){
for(auto [v,i]:e[u]){
dfs1(v);
For(i,0,2)sum[u][i]+=sum[v][i];
}
if(sum[u][0]==2) res+=g[fe[u]].c*((sum[u][1]*sum[u][1]-sum[u][2])/2);
}
struct TR{
int rt[maxn],ls[maxn],rs[maxn],d[maxn];
int val[maxn];
int merge(int u,int v){
if(!u||!v)return u|v;
if(val[u]<val[v])swap(u,v);
rs[u]=merge(rs[u],v);
if(d[ls[u]]<d[rs[u]])swap(ls[u],rs[u]);
d[u]=d[ls[u]]+1;
return u;
}
void mg(int u,int v){
rt[u]=merge(rt[u],rt[v]);
}
void pop(int&u){
u=merge(ls[u],rs[u]);
}
void pp(int u,int lim){
while(rt[u] && dep[g[rt[u]].v]>=lim) pop(rt[u]);
}
void dfs(int u){
if(!u)return;
cout<<"x: "<<u<<" "<<val[u]<<"\n";
dfs(ls[u]),dfs(rs[u]);
}
void clear(){
For(i,1,n) rt[i]=0;
For(i,1,m) ls[i]=rs[i]=d[i]=val[i]=0;
}
}T,T2;
vector<pii>tmp;
void chk(int x,int y){
if(!x||!y)return;
tmp.pb(mkp(min(x,y),max(x,y)));
// cout<<"chk "<<x<<" "<<y<<"\n";
// if(x&&y&&mp.count(h[x]^h[y]))
// res+=g[x].c*g[y].c*mp[h[x]^h[y]];
}
void dfs2(int u){
// cout<<"dfs2 "<<u<<"\n";
for(auto [v,i]:e[u]){
dfs2(v);
T.mg(u,v);
}
T.pp(u,dep[u]);
if(T.rt[u] && fe[u]) /*cout<<"u:: "<<u<<"\n",T.dfs(T.rt[u]),*/chk(fe[u],T.rt[u]);
}
void dfs3(int u){
// cout<<"dfs3 "<<u<<"\n";
for(auto [v,i]:e[u]){
dfs3(v);
T.mg(u,v),T2.mg(u,v);
}
T.pp(u,dep[u]),T2.pp(u,dep[u]);
if(fe[u]) {
if(T.rt[u]) chk(fe[u],T.rt[u]);
if(T2.rt[u]) chk(fe[u],T2.rt[u]);
}
}
void solve1()
{
For(i,1,m) if(dep[g[i].u]<dep[g[i].v]) swap(g[i].u,g[i].v);
// For(i,1,m) cout<<g[i].u<<" "<<g[i].v<<" "<<g[i].c<<" "<<on[i]<<"\n"; puts("----------");
// 1 tree & 2 not
For(i,1,n) For(j,0,2) sum[i][j]=0;
For(i,1,m) if(!on[i]) {
int u=g[i].u,v=g[i].v,w=g[i].c;
sum[u][0]+=1,sum[v][0]-=1;
sum[u][1]+=w,sum[v][1]-=w;
sum[u][2]+=w*w,sum[v][2]-=w*w;
}
int lst=res;
dfs1(1);
T.clear();
For(i,1,m) if(!on[i]){
T.val[i]=dep[g[i].v];
T.rt[g[i].u]=T.merge(T.rt[g[i].u],i);
}
dfs2(1);
T.clear(),T2.clear();
For(i,1,m) if(!on[i]){
int x=g[i].u;
T.val[i]=dfn[x];
T.rt[x]=T.merge(T.rt[x],i);
T2.val[i]=-dfn[x];
T2.rt[x]=T2.merge(T2.rt[x],i);
}
dfs3(1);
// puts("qaq");
mp.clear();
For(i,1,m) if(on[i]) mp[h[i]]+=g[i].c;
sort(ALL(tmp));
tmp.erase(unique(ALL(tmp)),tmp.end());
for(auto [x,y]:tmp){
// cout<<"chk "<<x<<" "<<y<<" "<<mp[h[x]^h[y]]<<"\n";
if(mp.count(h[x]^h[y]))
res+=g[x].c*g[y].c*mp[h[x]^h[y]];
}
tmp.clear();
}
// for rest
void solve()
{
init();
// For(i,1,m)cout<<g[i].u<<" "<<g[i].v<<" "<<g[i].c<<"\n";
// For(i,1,m) cout<<h[i]<<" " ;cout<<" h\n";
// For(i,1,m) For(j,i+1,m) For(k,j+1,m)
// if((h[i]^h[j]^h[k])==0) res+=g[i].c*g[j].c*g[k].c; return;
while(n>1){
init();
solve1();
For(i,1,m) if(!on[i]) cut[i]=1,g[i].c=0;
build();
}
}
signed main()
{
n=read(),m=read(),msum=m;
For(i,1,n)dsu[i]=i;
For(i,1,m)g[i].u=read(),g[i].v=read(),g[i].c=1;
solve0();
solve();
cout<<res;
return 0;
}
/*
*/