P1395 会议[树]

会议

题目描述

有一个村庄居住着 \(n\) 个村民,有 \(n-1\) 条路径使得这 \(n\) 个村民的家联通,每条路径的长度都为 \(1\)。现在村长希望在某个村民家中召开一场会议,村长希望所有村民到会议地点的距离之和最小,那么村长应该要把会议地点设置在哪个村民的家中,并且这个距离总和最小是多少?若有多个节点都满足条件,则选择节点编号最小的那个点。

输入格式

第一行,一个数 \(n\),表示有 \(n\) 个村民。

接下来 \(n-1\) 行,每行两个数字 \(a\)\(b\),表示村民 \(a\) 的家和村民 \(b\) 的家之间存在一条路径。

输出格式

一行输出两个数字 \(x\)\(y\)

\(x\) 表示村长将会在哪个村民家中举办会议。

\(y\) 表示距离之和的最小值。

样例 #1

样例输入 #1

4
1 2 
2 3 
3 4

样例输出 #1

2 4

提示

数据范围

对于 \(70\%\) 数据 \(n \le 10^3\)

对于 \(100\%\) 数据 \(n \le 5 \times 10^4\)

Solution

容易发现题设是一颗树,那么先把问题分解成与以各节点为根的子树的子问题。显然,任意节点的距离总和可以被分解成两部分:它的子树到它的距离总和和整个树中其它节点到它的距离总和,而后者可以从更大的子树中计算得到。

\(mulLen[i]\) 表示以 \(i\) 为根的子树到 \(i\) 的距离之和,\(subPoint[i]\) 表示以 \(i\) 为根的子树的节点数量(含根),可以得到转移方程。

\[subPoint[i] = \sum_{i\in subTree(u)} subPoint[v] \]

\[mulLen[i] = \sum_{i\in subTree(u)}mulLen[v] + subPoint[v] \]

那么设 \(dp[i]\) 为所有节点到 \(i\) 的距离总和,不妨假设整棵树的节点数为 \(n\),则有

\[dp[i] = dp[father[i]] + n - subPoint[i] * 2 \]

这可以根据这节点的父节点和它之间的边的关系推出。

一次回溯求,一次递归求,就行了。

Code

  • 这是一开始不知道哪根筋抽了写的 \(O(n^2)\) 算法,T 2个点,随后发现 \(calc\) 其实可以从上往下计算变成 \(O(n)\)
#include <cstdio>
#include <iostream>
#include <cstring>
#include <vector>
#define MAXN 50010
#define INF 0x7fffffff
using namespace std;
struct Ver {
    int u, v, next;
}ver[MAXN << 1];
int head[MAXN], fa[MAXN], cnt;
int mulLen[MAXN], subPoint[MAXN];
void add(int u, int v)
{
    ver[++cnt].next = head[u];
    ver[cnt].u = u; ver[cnt].v = v;
    head[u] = cnt;
}

void process(int u)
{
    for(int i = head[u]; i; i = ver[i].next) {
        int v = ver[i].v;
        if(fa[u] == v) continue;
        fa[v] = u;
        process(v);
        subPoint[u] += subPoint[v];
        mulLen[u] += mulLen[v] + subPoint[v];
    }
}

int calc(int u)
{
    if(fa[u] == 0) return mulLen[u];
    int sum = 0, cnt = 0;
    sum += mulLen[u];
    while(fa[u]) {
        int f = fa[u]; cnt++;
        sum += mulLen[f] - mulLen[u] - subPoint[u] + cnt;
        for(int i = head[f]; i; i = ver[i].next) {
            int v = ver[i].v;
            if(v == u || v == fa[f]) continue;
            sum += subPoint[v] * cnt;
        }
        u = fa[u];
    }
    return sum;
}
int main()
{
    int n;
    scanf("%d", &n);
    for(int i = 1; i <= n; i++) subPoint[i] = 1;
    for(int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v); add(v, u);
    }
    process(1);
    int min = INF, ans = 1;
    for(int i = 1; i <= n; i++) {
        int temp = calc(i);
        if(min > temp) {
            min = temp;
            ans = i;
        }
    }
    printf("%d %d\n", ans, min);
    return 0;
}
  • 这个是正解,\(clac\) 变成 \(O(n)\) 了。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <vector>
#define MAXN 50010
#define INF 0x7fffffff
using namespace std;
struct Ver {
    int u, v, next;
}ver[MAXN << 1];
int head[MAXN], fa[MAXN], cnt;
int mulLen[MAXN], subPoint[MAXN], dp[MAXN];
void add(int u, int v)
{
    ver[++cnt].next = head[u];
    ver[cnt].u = u; ver[cnt].v = v;
    head[u] = cnt;
}

void process(int u)
{
    for(int i = head[u]; i; i = ver[i].next) {
        int v = ver[i].v;
        if(fa[u] == v) continue;
        fa[v] = u;
        process(v);
        subPoint[u] += subPoint[v];
        mulLen[u] += mulLen[v] + subPoint[v];
    }
}

void calc(int u)
{
    for(int i = head[u]; i; i = ver[i].next) {
        int v = ver[i].v;
        if(fa[u] == v) continue;
        dp[v] = dp[u] + subPoint[1] - subPoint[v] * 2;
        calc(v);
    }
}
    
int main()
{
    int n;
    scanf("%d", &n);
    for(int i = 1; i <= n; i++) subPoint[i] = 1;
    for(int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v); add(v, u);
    }
    process(1);
    dp[1] = mulLen[1];
    calc(1);
    int ans = INF, minnode = 1;
    for(int i = 1; i <= n; i++)
        if(ans > dp[i]) {
            ans = dp[i];
            minnode = i;
        }
    printf("%d %d\n", minnode, ans);
    return 0;
}
posted @ 2022-10-06 16:18  DarkValkyrie  阅读(56)  评论(0编辑  收藏  举报