Codeforces 733F Drivers Dissatisfaction
题意:有n个点,m条边,每条边有不满意度w[i],以及减小一个不满意度代价c[i],问给你s元用来减少代价,找到一个总不满意度最小的生成树,保证有解。(减少后的不满意度可以为负数)
思路:
显然所有的钱都应该用在生成树中c最小的那条边上
先求出以w[i]为权的最小生成树O(nlogn)
答案一定是在现在求出的最小生成树基础上换掉一条边 或 不变
把所有边进行替换尝试,找到换掉后的最优解。(可以只尝试c不大于mst中最小的c的边)
添加一条边后需要尝试换掉,与该边组成环后的最大权的边,
找最大权就是找新增边的两个节点到最近公共祖先的路径上的最大边
具体见代码
const int maxn = 2 * 100000 + 100; struct Edge { int id, u, v, w, c;//从u到v权为w bool operator < (const Edge& rhs) const { if (w != rhs.w) return w < rhs.w; return id < rhs.id; } }; vector<Edge> e; int n, m; int pa[maxn]; vector<pair<int, bool> > G[maxn]; bool vis[maxn]; //被选中的边 int find(int x) { return pa[x] != x ? pa[x] = find(pa[x]) : x; } int kruskal() { int min_c = INF; for (int i = 1; i <= n; i++) pa[i] = i; sort(e.begin(), e.end()); for (int i = 0; i < e.size(); i++) { int x = find(e[i].u), y = find(e[i].v); if (x != y) { vis[e[i].id] = true; min_c = min(min_c, e[i].c); G[e[i].u].push_back(mp(i, 0)); G[e[i].v].push_back(mp(i, 1)); pa[x] = y; } } return min_c; } const int maxlog = 20; int fa[maxn]; // 父亲数组 int cost[maxn]; // 和父亲的费用 int L[maxn]; // 层次(根节点层次为0) struct LCA { int anc[maxn][maxlog]; // anc[p][i]是结点p的第2^i级父亲。anc[i][0] = fa[i] int maxcost[maxn][maxlog]; // maxcost[p][i]是i和anc[p][i]的路径上的最大费用 // 预处理,根据fa和cost数组求出anc和maxcost数组 void preprocess() { for(int i = 1; i <= n; i++) { anc[i][0] = fa[i]; maxcost[i][0] = cost[i]; for(int j = 1; (1 << j) <= n; j++) anc[i][j] = -1; } for(int j = 1; (1 << j) <= n; j++) { for(int i = 1; i <= n; i++) { if(anc[i][j-1] != -1) { int a = anc[i][j-1]; anc[i][j] = anc[a][j-1]; maxcost[i][j] = max(maxcost[i][j-1], maxcost[a][j-1]); } } } } // 求p到q的路径上的最大权 pii query(int p, int q) { int tmp, power, i; if(L[p] < L[q]) swap(p, q); //L[p] >= L[q] for(power = 1; (1 << power) <= L[p]; power++); power--; //(2^power <= L[p]中的最大的) int ans = -INF; for(int i = power; i >= 0; i--) { if (L[p] - (1 << i) >= L[q]) { ans = max(ans, maxcost[p][i]); p = anc[p][i]; } } if (p == q) return mp(ans, p); // LCA为p for(int i = power; i >= 0; i--) { if(anc[p][i] != -1 && anc[p][i] != anc[q][i]) { ans = max(ans, maxcost[p][i]); p = anc[p][i]; ans = max(ans, maxcost[q][i]); q = anc[q][i]; } } ans = max(ans, cost[p]); ans = max(ans, cost[q]); return mp(ans, fa[p]); // LCA为fa[p](它也等于fa[q]) } } lca; int w[maxn], c[maxn]; int all; void init() { scanf("%d%d", &n, &m); for (int i = 0; i < m; i++) { scanf("%lld", w + i); } for (int i = 0; i < m; i++) { scanf("%lld", c + i); } int u, v; for (int i = 0; i < m; i++) { scanf("%d%d", &u, &v); e.push_back((Edge){i, u, v, w[i], c[i]}); } scanf("%d", &all); } void bfs() { queue<int> q; q.push(1); fa[1] = 0; L[1] = 0; cost[1] = 0; while (!q.empty()) { int u = q.front(); q.pop(); for (auto i : G[u]) { int v = e[i.x].v; if (i.y) v = e[i.x].u; if (fa[u] == v) { continue; } q.push(v); fa[v] = u; L[v] = L[u] + 1; cost[v] = e[i.x].w; } } } int find_idx(int u, int v) //找边uv的编号 { if (u == -1) cout << "error!" << endl; for (int i = 0; i < G[u].size(); i++) { int to = e[G[u][i].x].v; if (G[u][i].y) to = e[G[u][i].x].u; if (to == v) return e[G[u][i].x].id; } } Edge find_change(const int min_c, LL& sub) { Edge ans = (Edge){-1}; for (auto i : e) { if (vis[i.id] || i.c > min_c) continue; LL max_wc = lca.query(i.u, i.v).x; LL new_sub = all / i.c - i.w + max_wc; if (new_sub > sub) { sub = new_sub; ans = i; } } return ans; } int find_delete(const Edge& ans) { pii father = lca.query(ans.u, ans.v); pii del(-1, -1); //删除 del for (int i = ans.u; i != father.y; i = fa[i]) { if (cost[i] == father.x) { del = mp(i, fa[i]); break; } } if (del.x == -1) { for (int i = ans.v; i != father.y; i = fa[i]) { if (cost[i] == father.x) { del = mp(i, fa[i]); break; } } } return find_idx(del.x, del.y); } void solve() { int min_c = kruskal(); /* cout << "MST:" <<endl; for (auto i : e) if (vis[i.id]) cout << i.id +1 << " "; cout << endl; */ bfs(); lca.preprocess(); LL sub = all / min_c; Edge ans = find_change(min_c, sub); LL sum = 0; for (auto i : e) { if (vis[i.id]) sum += i.w; } printf("%lld\n", sum - sub); if (ans.id == -1) { bool flag = true; for (auto i : e) { if (vis[i.id]) { if (flag && i.c == min_c) { printf("%d %d\n", i.id + 1, i.w - all / i.c); flag = false; } else printf("%d %d\n", i.id + 1, i.w); } } return; } int idx = find_delete(ans); for (auto i : e) { if (vis[i.id] && i.id != idx) { printf("%d %d\n", i.id + 1, i.w); } } printf("%d %d\n", ans.id + 1, ans.w - all/ans.c); } int main() { init(); solve(); return 0; }