JZOI 4311 统一天下
题目描述:
从前有两个国家$W_1$, $W_2$。国家$W_i$($i∈ \{1,2\})$有$n_i$座城市,这$n_i$座城市由$n_i-1$条双向道路相连。任意一个国家的内部都是连通的。一个国家的两个点之间存在唯一的最短路,两个点的距离是这条最短路上边的数目。(也可以理解为道路的长度均为$1$。) 注意$W_1$和$W_2$并不联通。
有一天,$W_1$的军队在小H的带领下一举攻破了$W_2$,建立了大一统的政权$W$。为了使全国连通,小H决定在$W_1$的某个点和$W_2$的某个点之间新修建一条路,使得$n=n_1+n_2$个点两两组成的$\dfrac{n(n-1)}{2}$个点对的距离和最小。
小H请你帮帮他,算一下这个最小距离。
题解:
首先,官方题解有大大的锅!!! 他给的式子居然是错的!!
吐槽完了,进入正题。
不妨将两个国家分开来看。设$dp_{(1,i)}$表示一个图中,所有的点到$i$点的距离和,$g_1$为一个图中所有的点的相互间距离的和。
那么,我们要求的柿子就是:
$dp_{(1,i)} * n_2 + dp_{(2,j)} * n_1 + n_1 * n_2 + g_1 + g_2$
其中,$dp_{(1,i)} * n_2$表示一张图的所有点连向另一张图的每个点的距离和。注意到连接后会新加入一条边,所以加上$n_1 * n_2$,即该边会被经过这么多次。
最后$g_1$和$g_2$即为在两边的原图中内部配对。
考虑如何求$dp_{(1,i)}$: 一遍$dfs$,我们可以直接求出$dp{(1,1)}$, 然后有递推柿:
$dp_{(1,son)} = dp{(1, father)} + n - siz_{son} * 2$ 其中,$siz_{son}$代表儿子的子树大小(包括自己)。
对于$g_1$, 直接将所有的$dp_{(1,i)}$相加然后除以$2$即可。
最后考虑如何求答案: 不难发现,$i, j$的贡献是分开的,没有联系、所以找到最小的$dp_{(1,i)}$ 和最小的$dp_{(2,j)}$,然后算出上面的柿子即可。
#include <bits/stdc++.h> using namespace std; #define int long long const int N = (int)6e5 + 10; template <class T> inline void read(T& a){ T x = 0, s = 1; char c = getchar(); while(!isdigit(c)){ if(c == '-') s = -1; c = getchar(); } while(isdigit(c)){ x = x * 10 + (c ^ '0'); c = getchar(); } a = x * s; return ; } struct node{ public: int v, next; public: node(int v = 0, int next = 0){ this -> v = v; this -> next = next; return ; } } t[N << 1]; int f[N]; int bian = 0; inline void add(int u, int v){ t[++bian] = node(v, f[u]), f[u] = bian; t[++bian] = node(u, f[v]), f[v] = bian; return ; } int n1, n2; int dp[N]; int siz[N], dis[N]; int G[2]; int tot; #define v t[i].v void dfs(int now, int father){ siz[now] = 1; for(int i = f[now]; i; i = t[i].next){ if(v != father){ dis[v] = dis[now] + 1; tot += dis[v]; dfs(v, now); siz[now] += siz[v]; } } return ; } void dfs2(int now, int father, int n){ for(int i = f[now]; i; i = t[i].next){ if(v != father){ dp[v] = dp[now] + n - siz[v] * 2; dfs2(v, now, n); } } return ; } #undef v signed main(){ // freopen("hh.txt", "r", stdin); read(n1), read(n2); for(int i = 1 ; i < n1; i++){ int x, y; read(x), read(y); add(x, y); } for(int i = 1; i < n2; i++){ int x, y; read(x), read(y); add(x + n1, y + n1); } dfs(1, 0); dp[1] = tot; tot = 0; dfs(1 + n1, 0); dp[n1 + 1] = tot; dfs2(1, 0, n1); dfs2(n1 + 1, 0, n2); int tot1 = 0, tot2 = 0; int minn1 = (int)1e18, minn2 = (int)1e18; for(int i = 1 ; i <= n1; i++) tot1 += dp[i], minn1 = min(minn1, dp[i]); for(int i = n1 + 1;i <= n1 + n2; i++) tot2 += dp[i], minn2 = min(minn2, dp[i]); G[0] = tot1 / 2, G[1] = tot2 / 2; int ans = minn1 * n2 + minn2 * n1 + n1 * n2 + G[0] + G[1]; cout << ans << endl; return 0; }