首先,强制要求 $K$ 条待定价的边全部在 MST 中(先把它们并入并查集),再对原图的边做 Kruskal,建出 MST。
我们发现,这个 MST 中被选上的原图边,无论最终用哪些待定价的边,都一定会被选上,所以可以把它们连接的点缩起来,只剩下至多 $K+1$ 个点;在缩点后的图上再做一次 Kruskal,只会留下至多 $K$ 条有用的原图边,它们构成一棵树。
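然后用 $2^K$ 枚举每条待定价的边是否在最终的 MST 中:先加入选中的边(若成环则该方案不合法),再用上面那至多 $K$ 条原图边做 Kruskal 补全成树。每条没进入树的原图边,都会限制它两端点在树上路径上的每条待定价边的定价不超过它的权值。为了让在最终 MST 中的待定价的边尽量大,每条边取这些限制中的最小值,在 Kruskal 之后沿路径暴力更新即可。一条边的收益等于它的定价乘以它下方子树的总人数。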
时间复杂度 $O(m \log m + 2^K \cdot K^2)$。
/**************************************************************
    Problem: 3206
    User: rausen
    Language: C++
    Result: Accepted
    Time:8040 ms
    Memory:8232 kb
****************************************************************/

// Outline of the program:
//  1. Union all k new edges first (fa[0]), then run Kruskal over the sorted
//     original edges. Every original edge accepted in that pass is also
//     unioned in fa[1]; fa[1] components are contracted into one node each
//     (populations summed into tr[root].v), giving cnt_root nodes in root[].
//  2. Run Kruskal once more over the contracted graph; the accepted original
//     edges are kept in s[1..top].
//  3. dfs() enumerates every subset u[] of new edges and work() evaluates it.

#include <cstdio>
#include <algorithm>

using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
const int M = 3e5 + 5;
const int K = 25;
const int inf = 1e9;

inline int read();

struct Edge {
    int x, y, v;

    // Reads the two endpoints; also reads the weight when f is non-zero
    // (the new edges are read with f == 0 and carry no weight).
    inline void get(int f) {
        x = read(), y = read();
        if (f) v = read();
    }

    inline bool operator < (const Edge &E) const {
        return v < E.v;
    }
} E[M], Ek[K], s[K];  // original edges, new edges, kept edges of step 2

// Adjacency list of the tree built by work() for one subset.
struct edge {
    int next, to;
    edge() {}
    edge(int _n, int _t) : next(_n), to(_t) {}
} e[K << 1];

int first[N], tot;

struct tree_node {
    int fa, dep, mn;  // tree parent, depth, smallest price cap on edge to fa
    ll v, sum;        // population of the node, population of its subtree
} tr[N];

int n, m, k, S, top;
int fa[2][N];           // union-find parents: [0] Kruskal / scratch, [1] contraction
int root[K], cnt_root;  // representatives of the contracted components
int u[K];               // u[i] != 0 <=> new edge i is used in the current subset
ll ans;

// Adds the undirected tree edge x-y to the adjacency list.
inline void Add_Edges(int x, int y) {
    e[++tot] = edge(first[x], y), first[x] = tot;
    e[++tot] = edge(first[y], x), first[y] = tot;
}

// Returns the representative of x in union-find forest f, compressing the
// whole path. Iterative, so long parent chains built by the unions in main()
// cannot overflow the call stack.
int find(int x, int f) {
    int r = x;
    while (fa[f][r] != r) r = fa[f][r];
    while (fa[f][x] != r) {
        int nx = fa[f][x];
        fa[f][x] = r;
        x = nx;
    }
    return r;
}

// DFS from p: assigns dep/fa to p's children and accumulates subtree
// population into tr[].sum.
void dp(int p) {
    tr[p].sum = tr[p].v;
    for (int j = first[p]; j; j = e[j].next) {
        int q = e[j].to;
        if (q != tr[p].fa) {
            tr[q].dep = tr[p].dep + 1, tr[q].fa = p;
            dp(q);
            tr[p].sum += tr[q].sum;
        }
    }
}

// Evaluates the subset of new edges selected in u[] and returns its revenue,
// or 0 when the selected edges form a cycle.
ll work() {
    // Reset per-subset state for every contracted node (only those exist).
    tot = 0;
    for (int i = 1; i <= cnt_root; ++i) {
        int p = root[i];
        fa[0][p] = p;
        first[p] = tr[p].fa = 0, tr[p].mn = inf;
    }
    // Selected new edges go in first, as in the outline's step 1.
    for (int i = 1; i <= k; ++i)
        if (u[i]) {
            int x = find(Ek[i].x, 0), y = find(Ek[i].y, 0);
            if (x == y) return 0;  // selected edges form a cycle: infeasible
            fa[0][x] = y;
            Add_Edges(Ek[i].x, Ek[i].y);
        }
    // Complete the tree with the kept original edges, in weight order.
    for (int i = 1; i <= top; ++i) {
        int x = find(s[i].x, 0), y = find(s[i].y, 0);
        if (x != y) fa[0][x] = y, Add_Edges(s[i].x, s[i].y);
    }
    tr[S].dep = 0;
    dp(S);
    // Each kept edge caps the price of every tree edge on the path between
    // its endpoints (the tree edge is stored on its deeper endpoint).
    for (int i = 1; i <= top; ++i) {
        int x = s[i].x, y = s[i].y;
        if (tr[x].dep < tr[y].dep) swap(x, y);
        while (tr[x].dep != tr[y].dep)
            tr[x].mn = min(tr[x].mn, s[i].v), x = tr[x].fa;
        while (x != y) {
            tr[x].mn = min(tr[x].mn, s[i].v);
            tr[y].mn = min(tr[y].mn, s[i].v);
            x = tr[x].fa, y = tr[y].fa;
        }
    }
    // Revenue of a selected new edge = its cap * population below it.
    ll res = 0;
    for (int i = 1; i <= k; ++i)
        if (u[i]) {
            int x = Ek[i].x, y = Ek[i].y;
            int c = tr[x].dep > tr[y].dep ? x : y;
            res += tr[c].mn * tr[c].sum;
        }
    return res;
}

// Enumerates all 2^k choices of u[p..k] and records the best revenue in ans.
void dfs(int p) {
    if (p == k + 1) {
        ans = max(ans, work());
        return;
    }
    u[p] = 0, dfs(p + 1);
    u[p] = 1, dfs(p + 1);
}

int main() {
    n = read(), m = read(), k = read();
    for (int i = 1; i <= m; ++i) E[i].get(1);
    for (int i = 1; i <= k; ++i) Ek[i].get(0);
    for (int i = 1; i <= n; ++i) tr[i].v = read();
    sort(E + 1, E + m + 1);
    for (int i = 1; i <= n; ++i) fa[0][i] = fa[1][i] = i;

    // Step 1: new edges first, then Kruskal over the original edges; every
    // original edge accepted here is also merged in fa[1].
    for (int i = 1; i <= k; ++i) {
        int a = find(Ek[i].x, 0), b = find(Ek[i].y, 0);
        fa[0][a] = b;
    }
    for (int i = 1; i <= m; ++i) {
        int a = find(E[i].x, 0), b = find(E[i].y, 0);
        if (a != b) {
            fa[0][a] = b;
            int c = find(E[i].x, 1), d = find(E[i].y, 1);
            fa[1][c] = d;
        }
    }

    // Contract fa[1] components: sum populations into the representative.
    S = find(1, 1);
    for (int i = 1; i <= n; ++i) {
        int r = find(i, 1);
        if (r != i) tr[r].v += tr[i].v;
        else root[++cnt_root] = i;
    }
    for (int i = 1; i <= k; ++i)
        Ek[i].x = find(Ek[i].x, 1), Ek[i].y = find(Ek[i].y, 1);
    for (int i = 1; i <= m; ++i)
        E[i].x = find(E[i].x, 1), E[i].y = find(E[i].y, 1);

    // Step 2: Kruskal on the contracted graph; accepted edges go to s[].
    for (int i = 1; i <= m; ++i) {
        int a = find(E[i].x, 1), b = find(E[i].y, 1);
        if (a != b) {
            s[++top] = E[i];
            fa[1][a] = b;
        }
    }

    dfs(1);
    printf("%lld\n", ans);
    return 0;
}

// Reads the next unsigned decimal integer from stdin, skipping any other
// characters. Returns 0 (instead of spinning forever) if EOF is reached
// before a digit.
inline int read() {
    int x = 0, ch = getchar();
    while (ch != EOF && (ch < '0' || '9' < ch))
        ch = getchar();
    while ('0' <= ch && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x;
}
By Xs酱~ 转载请说明
博客地址:http://www.cnblogs.com/rausen