【bzoj5206】[Jsoi2017]原力 根号分治+STL-map
题目描述
一个原力网络可以看成是一个可能存在重边但没有自环的无向图。每条边有一种属性和一个权值。属性可能是R、G、B三种当中的一种,代表这条边上原力的类型。权值是一个正整数,代表这条边上的原力强度。原力技术的核心在于将R、G、B三种不同的原力融合在一起产生单一的、便于利用的原力。为了评估一个能源网络,JYY需要找到所有满足要求的三元环(首尾相接的三条边),其中R、G、B三种边各一条。一个三元环产生的能量是其中三条边的权值之积。
现在对于给出的原力网络,JYY想知道这个网络的总能量是多少。网络的总能量是所有满足要求三元环的能量之和。
输入
第一行包含两个正整数N、M。表示原力网络的总顶点个数和总边数。
接下来M行,每行包含三个正整数ui,vi,wi和一个字符ci。
表示编号ui和vi的顶点之间存在属性为ci权值为wi的一条边。
N≤50,000,M≤100,000,1≤?Wi≤10^6
输出
输出一行一个整数,表示这个原力网络的总能量模10^9+7的值
样例输入
4 6
1 2 2 R
2 4 3 G
4 3 5 R
3 1 7 G
1 4 11 B
2 3 13 B
样例输出
828
题解
根号分治+STL-map
看到这种根本没法写出什么玄学数据结构之类的,大概率就是根号分治了。
对于本题,由于边数只有 $m$ ,因此度数大于等于 $\sqrt m$ 的点只有 $O(\sqrt m)$ 个,我们称这样的点为大点,度数小于 $\sqrt m$ 的称为小点。
那么对于一个三元环:
如果三个点都是大点:这种情况下我们暴力枚举三个大点,求出是否有满足条件的三元环并加入到答案中即可。时间复杂度为 $O((\sqrt m)^3)=O(m\sqrt m)$ ;
如果三个点中有小点:这种情况下我们枚举每个小点和它的两条出边,判断这三个点是否有满足条件的三元环。此时,枚举第一条出边相当于枚举图中所有边,第二条出边是度数复杂度,而度数小于 $\sqrt m$ ,因此复杂度也是 $O(m\sqrt m)$ 的。注意这个过程需要保证不重不漏,因此只考虑枚举点为这三个点中编号最小的小点的答案。
那么如何判断是否有满足条件的三元环呢?我偷懒了使用STL-map判断两点之间有没有某颜色的边,复杂度上会多一个log。
时间复杂度 $O(m\sqrt m\log m)$ ,实际上跑得挺快的 然而在bz上还是倒数第一...
#include <map> #include <cmath> #include <cstdio> #define N 50010 #define mod 1000000007 using namespace std; typedef long long ll; struct data { int x , y , z; data() {} data(int a , int b , int c) {x = a , y = b , z = c;} bool operator<(const data &a)const {return x == a.x ? y == a.y ? z < a.z : y < a.y : x < a.x;} }; map<data , ll> mp; int head[N] , to[N << 2] , val[N << 2] , opt[N << 2] , next[N << 2] , cnt , d[N] , id[350] , tot; char str[5]; inline void add(int x , int y , int v , int c) { to[++cnt] = y , val[cnt] = v , opt[cnt] = c , next[cnt] = head[x] , head[x] = cnt; } int main() { int n , m , si , i , j , k , x , y , z , t; ll ans = 0; scanf("%d%d" , &n , &m) , si = (int)sqrt(m); for(i = 1 ; i <= m ; i ++ ) { scanf("%d%d%d%s" , &x , &y , &z , str); t = (str[0] == 'R' ? 1 : str[0] == 'G' ? 2 : 3); add(x , y , z , t) , add(y , x , z , t) , d[x] ++ , d[y] ++ ; (mp[data(x , y , t)] += z) %= mod , (mp[data(y , x , t)] += z) %= mod; } for(i = 1 ; i <= n ; i ++ ) if(d[i] >= si) id[++tot] = i; for(i = 1 ; i <= tot ; i ++ ) for(j = 1 ; j <= tot ; j ++ ) for(k = 1 ; k <= tot ; k ++ ) ans = (ans + mp[data(id[i] , id[j] , 1)] * mp[data(id[i] , id[k] , 2)] % mod * mp[data(id[j] , id[k] , 3)]) % mod; for(i = 1 ; i <= n ; i ++ ) if(d[i] < si) for(j = head[i] ; j ; j = next[j]) if(d[to[j]] >= si || to[j] > i) for(k = next[j] ; k ; k = next[k]) if(opt[k] != opt[j] && (d[to[k]] >= si || to[k] > i)) ans = (ans + mp[data(to[j] , to[k] , 6 - opt[j] - opt[k])] * val[j] % mod * val[k]) % mod; printf("%lld\n" , ans); return 0; }