E. Sergey and Subway

比赛时候写复杂了……

我写的是 计算每个节点树内所有点到某个点的距离和。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 2e5 + 50;
vector<int> g[maxn];
int son[maxn];
ll d[maxn]; ///树内所有点到某个点的距离和
int odd[maxn];
int even[maxn];
void dfs1(int u, int fa)
{
    d[u] = 0;
    son[u] = 0;
    son[u]++;
    even[u] = 1;
    for(int i = 0; i < (int)g[u].size(); i++)
    {
        int v = g[u][i];
        if(v == fa) continue;
        dfs1(v, u);
        son[u] += son[v];
        even[u] += odd[v];
        d[u] += d[v] + son[v];
    }
    odd[u] = son[u] - even[u];
   // printf("%d %I64d\n", u, d[u]);
}
ll ans = 0;
void dfs2(int u, int fa)
{
    for(int i = 0; i < (int)g[u].size(); i++)
    {
        int v = g[u][i];
        if(v == fa) continue;
        d[v] = d[v] + (d[u] - d[v] - (ll)son[v]) + (ll)(son[u] - son[v]);
        odd[v] += (son[u] - son[v] - odd[u] + even[v]);
        son[v] = son[u];
        ans += ((d[v] + odd[v]) / 2LL);
        dfs2(v, u);
    }
}
int main()
{
    int n; scanf("%d", &n);
    for(int i = 1; i <= n - 1; i++)
    {
        int u, v; scanf("%d %d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs1(1, 0);
    dfs2(1, 0);
    ans += (d[1] + odd[1]) / 2LL;
    printf("%lld\n", ans / 2LL);
    return 0;
}
Code

 实际上这题只需要计算树上任意两点的距离和,枚举每条边的贡献就行了。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 2e5 + 50;
vector<int> g[maxn];
int n;
ll ans, odd, even;
int dfs(int u, int fa, int dep)
{
    if(dep % 2) odd++;
    else even++;
    int son = 1; ///统计子树大小
    for(int i = 0; i < (int)g[u].size(); i++)
    {
        int v = g[u][i];
        if(v == fa) continue;
        int sing = dfs(v, u, dep + 1);
        ans += (ll)sing * (n - sing); ///枚举每条边的贡献
        son += sing;
    }
    return son;
}
int main()
{
    scanf("%d", &n);
    for(int i = 1; i <= n - 1; i++)
    {
        int u, v; scanf("%d %d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1, 0, 0);
    ans = (ans + (ll)odd * even) / 2;
    printf("%lld\n", ans);
    return 0;
}
Code

 

posted @ 2018-10-05 10:24  汪汪鱼  阅读(298)  评论(0编辑  收藏  举报