【MX-J3-T3+】Tuple+ 题解
一个比较自然的思路就是对于每个三元组 \((u_i,v_i,w_i)\),把 \((v_i,w_i)\) 这个二元组放在属于 \(u_i\) 的 vector 里面。然后对于每一个 \(i\in [1,n-3]\),把 \(i\) 的 vector 里面的所有二元组 \((x,y)\) 当作一条连接 \(x,y\) 的无向边,则我们的目的是在图中找出所有的三元环 \((p_1,p_2,p_3)\),如果这个三元组被给出了,我们就让它与 \(i\) 组成一个合法的四元组 \((i,p_1,p_2,p_3)\)(\(i<p_1<p_2<p_3\))。
无向图找三元环的话就是一个比较模板的题目了。我们考虑给每条无向边定向,对于 \((x,y)\) 这条边,设 \(d_i\) 表示 \(i\) 的度数,若 \(d_x>d_y\) 或者 \(d_x=d_y\wedge x>y\),我们就定向为 \(x\to y\)。
然后我们就枚举每一个点 \(x\),然后枚举 \(x\) 指向的某个点 \(y\),再枚举 \(y\) 指向的另一个点 \(z\),若 \((x,z)\) 存在连边,则 \((x,y,z)\) 构成一个三元环。时间复杂度可以分类讨论证明一下:设边数为 \(M\),若 \(d_x\leq \sqrt{M}\),则 \(d_z\leq d_y\leq d_x\leq \sqrt{M}\),那么此次枚举就是 \(O(\sqrt{M})\) 的;若 \(d_x>\sqrt{M}\),则此次枚举最劣是 \(O(M)\) 的,而度数大于 \(\sqrt{M}\) 的点最多只有 \(\sqrt{M}\) 个,所以综合起来的复杂度就是 \(O(M\sqrt{M})\)。
由此,我们可以在规定的时间内找出所有的三元环,接着判一下每个三元环是否被给出即可。这里最好是先找完所有三元环,然后再一次性的挨个判断,不然判断时的 \(O(\log m)\) 可能会堆积在 \(O(M\sqrt{M})\) 这个时间复杂度上。
对于我们枚举的 \(i\),设其 vector 中存了 \(M_i\) 条边,则有 \(\sum M_i=m\),我们的总时间复杂度为 \(O(\sum M_i\sqrt{M_i})\)。
结论:\(a\sqrt{a}+b\sqrt{b}<(a+b)\sqrt{a+b}\),其中 \(a,b>0\)。
证明:将两边同时平方可得 \(a^3+b^3+2ab\sqrt{ab}<a^3+b^3+3ab^2+3a^2b\)。去掉相同项就能化简为 \(\sqrt{ab}<1.5(a+b)\),根据均值不等式 \(a+b\geq 2\sqrt{ab}\) 即可证明 \(\sqrt{ab}<1.5(a+b)\) 成立。
拓展一下这个结论可以知道 \(\sum M_i\sqrt{M_i}< (\sum M_i)\sqrt{\sum M_i}=m\sqrt{m}\)。因此我们的最劣时间复杂度为 \(O(m\sqrt{m})\)。
代码如下:
#include<bits/stdc++.h>
using namespace std;
const int MAXN=3e5+5;
int n,m;
int u[MAXN],v[MAXN],w[MAXN];
vector<pair<int,int> > vec[MAXN];
int stk[MAXN<<1],cnt;
int in[MAXN];
int head[MAXN],nxt[MAXN<<1],to[MAXN<<1],tot;
void add(int x,int y)
{
to[++tot]=y;
nxt[tot]=head[x];
head[x]=tot;
}
vector<int> v1[MAXN];
int vis[MAXN];
map<pair<pair<int,int>,int>,int> mp;
vector<pair<pair<int,int>,int> > temp;
int main()
{
cin>>n>>m;
for(int i=1;i<=m;i++) cin>>u[i]>>v[i]>>w[i],vec[u[i]].push_back(make_pair(v[i],w[i]));
int res=0;
for(int s=1;s<=n-3;s++)
{
if(!vec[s].size()) continue;
cnt=0,tot=0;
for(auto j:vec[s]) stk[++cnt]=j.first,stk[++cnt]=j.second,add(j.first,j.second),add(j.second,j.first);
for(auto j:vec[s]) in[j.first]++,in[j.second]++;
sort(stk+1,stk+cnt+1),cnt=unique(stk+1,stk+cnt+1)-stk-1;
for(int i=1;i<=cnt;i++)
{
for(int j=head[stk[i]];j;j=nxt[j])
{
if(in[stk[i]]>in[to[j]]||in[stk[i]]==in[to[j]]&&stk[i]>to[j]) v1[stk[i]].push_back(to[j]);
}
}
temp.clear();
for(int i=1;i<=cnt;i++)
{
for(int j:v1[stk[i]]) vis[j]=1;
for(int j:v1[stk[i]])
{
for(int k:v1[j])
{
if(vis[k]) temp.push_back(make_pair(make_pair(stk[i],j),k));
}
}
for(int j:v1[stk[i]]) vis[j]=0;
}
for(auto j:temp)
{
if(j.first.first>j.first.second) swap(j.first.first,j.first.second);
if(j.first.second>j.second) swap(j.first.second,j.second);
if(j.first.first>j.first.second) swap(j.first.first,j.first.second);
mp[j]++;
}
for(int i=1;i<=cnt;i++) head[stk[i]]=0,in[stk[i]]=0,v1[stk[i]].clear();
}
for(int i=1;i<=m;i++) res+=mp[make_pair(make_pair(u[i],v[i]),w[i])];
cout<<res;
return 0;
}