CF1089D Distance Sum

传送门

tourist 出的绝世好题


思路

首先,考虑一个范围更广的问题:

\[\sum_{u=1}^{n-1}\sum_{v=u+1}^n w_uw_vd(u,v) \]

\(w_u\) 表示 \(u\) 的点权,\(d(u,v)\) 表示 \(u,~v\) 之间最短路

显然原问题就是 \(w_u\) 都为 \(1\) 的情况

(为啥不简化题目还要复杂化?我知道你很急但你先别急

对于图中度数为 \(1\) 的点 \(u\),它显然只有唯一一条连边 \((u,v)\),那么如果要到 \(u\),显然必须要经过这条边

因此我们可以考虑先将这条边进行贡献:\(w_u\times (n-w_u)\)\(n\) 是总点数);然后我们将 \(w_u\) 加到 \(w_v\) 上,将边 \((u,v)\) 删去

重复上诉步骤,我们最后会得到若干个环组成的连通图 \(G\)

\(G\) 中,我们找出所有度数大于 \(2\) 的点作为特殊点(如果没有,说明 \(G\) 就是一个简单环,那么任选两个点),由于边数 \(m\le n+42\),那么特殊点的个数不会超过 \(84\)

我们考虑选择两个相邻的特殊点 \(l, r\),它们路径上的点记为 \(p_1,~p_2,~p_3,...,p_k\)

对于 \(p_i\),我们计算 \(p_j,~j>i\) 的点对它的贡献:显然有一部分会直接走路径 \((l,r)\) 内的路,有一部分会走到 \(r\) 后,通过另一部分的半环,走到 \(l\) ,再到 \(p_i\)

对于不在路径 \((l,r)\) 上的点(也就是不是 \(p\) 中的点),我们计算所有 \(p\) 的贡献:显然有一部分会经过 \(l\),另一部分会经过 \(r\),然后到达终点

我们可以记录两个前缀和数组 \(s_1,s_2\)

\[s_1[x]=\sum_{i=1}^x w_{p[i]} \]

\[s_2[x]=\sum_{i=1}^x w_{p[i]}\times i \]

然后就可以计算上面要的贡献

具体实现时,为了避免讨论,会让统计的答案是正确答案的两倍,记得输出时除以 \(2\)


代码

#include<iostream>
#include<fstream>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<queue>
#include<map>
#include<set>
#include<bitset>
#define LL long long
#define FOR(i, x, y) for(int i = (x); i <= (y); i++)
#define ROF(i, x, y) for(int i = (x); i >= (y); i--)
#define PFOR(i, x) for(int i = he[x]; i; i = r[i].nxt)
inline int rd()
{
    int sign = 1, re = 0; char c = getchar();
    while(c < '0' || c > '9'){if(c == '-') sign = -1; c = getchar();}
    while('0' <= c && c <= '9'){re = re * 10 + (c - '0'); c = getchar();}
    return sign * re;
}
int n, m, k; LL ans;
std::set<int> lk[100005];
std::queue<int> q;
int sz[100005];
std::vector<int> r[100005];
int SZ[100005], id[100005], dcnt;
int dis[90][100005];
inline void bfs(int st)
{
    FOR(i, 1, n) dis[st][i] = 1e9;
    q.push(st); dis[st][st] = 0;
    while(!q.empty())
    {
        int now = q.front(); q.pop();
        for(int to : r[now]) if(dis[st][to] > dis[st][now] + 1)
            dis[st][to] = dis[st][now] + 1,
            q.push(to);
    }
}
std::bitset<100005> vis, inlk;
std::vector<int> p;
LL s1[100005], s2[100005];
inline void calc(int st, int x)
{
    p.clear(); vis[st] = vis[x] = 1;
    p.emplace_back(st), p.emplace_back(x);
    while(x > k)
        for(int to : r[x]) if(p[p.size() - 2] != to)
        {
            p.emplace_back(to), x = to, vis[x] = 1;
            break;
        }
    int L = p[0], R = p.back(), len = p.size() - 1, csz = len + dis[L][R];
    // 小半环时 csz 是两倍 (L, R);大半环时 csz 是 L, R 所在环的大小
    FOR(i, 1, len - 1) s1[i] = s1[i - 1] + SZ[p[i]], s2[i] = s2[i - 1] + 1ll * i * SZ[p[i]];
    // 计算半环内的点之间相互的贡献
    LL sum = 0;
    FOR(i, 1, len - 1)
    {
        int t = std::min(len - 1, i + (csz >> 1));
        LL t1 = (s2[t] - s2[i - 1]) - 1ll * i * (s1[t] - s1[i - 1]);
        LL t2 = 1ll * (i + csz) * (s1[len - 1] - s1[t]) - (s2[len - 1] - s2[t]);
        sum += (t1 + t2) * SZ[p[i]];
    }
    ans += sum << 1;
    // 计算半环内的点对半环外的点的贡献
    FOR(i, 1, len - 1) inlk[p[i]] = 1;
    FOR(i, 1, n) if(!inlk[i])
    {
        int dl = dis[L][i], dr = dis[R][i];
        int t = std::min(len - 1, (dr - dl + len) >> 1);
        LL t1 = s2[t] + dl * s1[t];
        LL t2 = (dr + len) * (s1[len - 1] - s1[t]) - (s2[len - 1] - s2[t]);
        ans += (t1 + t2) * SZ[i];
    }
    FOR(i, 1, len - 1) inlk[p[i]] = 0;
}
signed main()
{
#ifndef ONLINE_JUDGE
    freopen("test.in", "r", stdin);
    freopen("test.out", "w", stdout);
#endif
    n = rd(), m = rd();
    FOR(i, 1, m)
    {
        int u = rd(), v = rd();
        lk[u].insert(v), lk[v].insert(u);
    }
    FOR(i, 1, n)
    {
        sz[i] = 1;
        if(lk[i].size() == 1) q.push(i);
    }
    while(!q.empty())
    {
        int now = q.front(); q.pop();
        for(int to : lk[now])
        {
            lk[to].erase(now);
            ans += 2ll * sz[now] * (n - sz[now]);
            sz[to] += sz[now];
            if(lk[to].size() == 1) q.push(to);
        }
        lk[now].clear();
    }
    FOR(i, 1, n) if(lk[i].size() > 2)
        id[i] = ++dcnt;
    k = dcnt ? dcnt : 2;
    FOR(i, 1, n) if(lk[i].size() == 2)
        id[i] = ++dcnt;
    FOR(i, 1, n)
    {
        for(int j : lk[i]) r[id[i]].emplace_back(id[j]);
        lk[i].clear(), SZ[id[i]] = sz[i];
    }
    n = dcnt;
    FOR(i, 1, k)
    {
        bfs(i);
        FOR(j, 1, n) ans += 1ll * dis[i][j] * SZ[i] * SZ[j];
    }
    FOR(i, 1, k) for(int to : r[i]) if(to > k && !vis[to])
        calc(i, to);
    printf("%lld", ans >> 1);
    return 0;
}
posted @ 2022-09-23 15:20  zuytong  阅读(17)  评论(0编辑  收藏  举报