「2019 集训队互测 Day 1」最短路径 题解

「2019 集训队互测 Day 1」最短路径 题解

题目传送门

算法标签: 分治,ntt。

这题主要考察了对于分治的应用。

首先考虑最简单的“树”的情况(有自环时整张图就是一棵树;环上每个点下面挂着的子树内部的点对也属于这种情况)。很容易想到,可以用点分治+卷积(NTT)实现。

然后只剩下环的情况了。

设环上第 \(i\) 个点下面挂着的树(包括该点本身)中,到第 \(i\) 个点距离为 \(j\) 的点的个数为 \([x^j]f_i\)。

设环长为\(len\)

我们从任意一个位置破环成链,然后再复制一份接在后面。

则我们要算出:\(\sum f_i\times f_j\times x^{j-i}\ (0\leq i<len,\ 1\leq j-i\leq \lfloor(len-1)/2\rfloor)\),此时 \(j-i\) 恰好是两点沿环的最短距离,且环上每对点只被统计一次。特殊地,当 \(len\) 为偶数时,还需要单独处理 \(j-i=len/2\)(正对面的两点,两个方向距离相同)的情况。

然后将这个长度为 \(2\cdot len\) 的序列切成若干段(一般是 4~5 段),每段长度为 \(\lfloor(len-1)/2\rfloor\)(最后一段可能更短),则同一段内部任意两点一定满足 \(j-i\leq \lfloor(len-1)/2\rfloor\),所以段内直接分治计算就可以了。

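然后考虑相邻两段之间的贡献(不相邻的两段中任意两点的 \(j-i\) 都超过 \(\lfloor(len-1)/2\rfloor\),没有贡献):对于右段中的每个 \(j\),左段中合法的 \(i\) 构成一个后缀 \([j-\lfloor(len-1)/2\rfloor,\ \text{段尾})\),其左端点随 \(j\) 单调右移,因此可以像 dp 决策单调性那样分治计算。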

code:

这份代码为了减小常数,段内分治采用带权分治(按各点多项式长度之和选取分割点,见 div1)。

/*
{
######################
#       Author       #
#        Gary        #
#        2021        #
######################
*/
#include <bits/stdc++.h>
// Competitive-programming shorthand used throughout the file.
#define rb(a,b,c) for(int a=b;a<=c;++a)   // inclusive ascending loop: a = b .. c
#define rl(a,b,c) for(int a=b;a>=c;--a)   // inclusive descending loop: a = b .. c
#define LL long long
#define IT iterator
#define PB push_back
#define II(a,b) make_pair(a,b)
#define FIR first
#define SEC second
#define FREO freopen("check.out","w",stdout)   // debugging aid: redirect stdout to check.out
#define rep(a,b) for(int a=0;a<b;++a)   // a = 0 .. b-1 (b is not parenthesized)
#define SRAND mt19937 rng(chrono::steady_clock::now().time_since_epoch().count())
#define random(a) rng()%a   // NOTE(review): argument not parenthesized; needs `rng` from SRAND in scope
#define ALL(a) a.begin(),a.end()
#define POB pop_back
#define ff fflush(stdout)
#define fastio ios::sync_with_stdio(false)
#define check_min(a,b) a=min(a,b)   // a = min(a, b)
#define check_max(a,b) a=max(a,b)   // a = max(a, b)
using namespace std;
// 0x3f3f3f3f (= 1061109567): INF + INF still fits in int.
const int INF = 0x3f3f3f3f;
// (int, int) pair; used for segment bounds in main() and (score, index) in div1().
typedef pair<int, int> mp;
/*
 * Reads the next non-negative decimal integer from stdin, skipping any
 * non-digit characters before it.
 * Returns 0 if end of input is reached before a digit is found (previously
 * this spun forever on EOF). `ch` is an int so EOF is distinguishable from
 * every character value even where plain char is unsigned.
 */
inline int read() {
    int x = 0;
    int ch = getchar();

    while (ch != EOF && (ch < '0' || ch > '9'))
        ch = getchar();

    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }

    return x;
}
// NTT template begins
// NTT modulus (prime, 998244353 = 119 * 2^23 + 1) and the generator g used to build roots of unity.
int MOD = 998244353;
int g = 3;
// Current transform length (a power of two); set by getlen(), read by butterfly(), ntt() and mul().
int len;
// Bit-reversal table rebuilt by butterfly(); holds 2^19 entries, so len must not exceed 1 << 19.
int rev[1 << 19];
void butterfly(vector<int> &v) {
    rep(i, len) {
        rev[i] = rev[i >> 1] >> 1;

        if (i & 1)
            rev[i] |= len >> 1;
    }

    rep(i, len) if (rev[i] > i)
        swap(v[i], v[rev[i]]);
}
// Modular exponentiation: returns A^B mod MOD for B >= 0 (recursive square-and-multiply).
int quick(int A, int B) {
    if (!B)
        return 1;

    const long long half = quick(A, B >> 1);
    const long long squared = half * half % MOD;

    if (!(B & 1))
        return static_cast<int>(squared);

    return static_cast<int>(squared * A % MOD);
}
// Modular inverse via Fermat's little theorem (MOD is prime): x^(MOD-2) mod MOD.
// Yields 0 when x is a multiple of MOD.
int inv(int x) {
    const int fermat_exponent = MOD - 2;
    return quick(x, fermat_exponent);
}
/*
 * Iterative number-theoretic transform of v, whose size must equal `len`
 * (a power of two set by getlen()).
 * ty == 1: forward transform. ty == -1: inverse transform WITHOUT the final
 * 1/len scaling (mul() applies it).
 * Values are reduced modulo MOD first. Butterflies are computed in place, so no
 * scratch buffer is allocated and copied back on every stage.
 */
vector<int> ntt(vector<int> v, int ty) {
    for (auto &it : v)
        it %= MOD;

    butterfly(v);

    for (int l = 2; l <= len; l <<= 1) {
        const int half = l >> 1;
        // Root of unity of order l (inverted for the inverse transform).
        int step = quick(g, (MOD - 1) / l);

        if (ty == -1)
            step = inv(step);

        for (int j = 0; j < len; j += l) {
            int now = 1;

            for (int k = 0; k < half; ++k) {
                // Positions j+k and j+k+half are read and written only by this
                // iteration, so updating v in place equals the two-buffer form.
                const int A = v[j + k];
                const int B = 1ll * now * v[j + half + k] % MOD;
                v[j + k] = (A + B) % MOD;
                v[j + k + half] = (A - B + MOD) % MOD;
                now = 1ll * now * step % MOD;
            }
        }
    }

    return v;
}
// Sets the global transform length `len` to the smallest power of two >= x (at least 1).
void getlen(int x) {
    int power = 1;

    while (power < x)
        power *= 2;

    len = power;
}
vector<int> mul(vector<int> A, vector<int> B) {
    getlen(A.size() + B.size());
    A.resize(len);
    B.resize(len);
    A = ntt(A, 1);
    B = ntt(B, 1);
    rep(i, len) A[i] = 1ll * A[i] * B[i] % MOD;
    A = ntt(A, -1);
    int iv = inv(len);
    rep(i, len) {
        A[i] = 1ll * A[i] * iv % MOD;
    }

    while (!A.empty() && A.back() == 0)
        A.pop_back();

    return A;
}
void add(vector<int> &A, vector<int> B) {
    if (A.size() < B.size())
        A.resize(B.size());

    rep(i, B.size()) {
        (A[i] += B[i]) %= MOD;
    }
}
// Multiplies polynomial A by X^x, i.e. prepends x zero coefficients.
// A non-positive x leaves A unchanged.
vector<int> right_shift(vector<int> A, int x) {
    if (x > 0)
        A.insert(A.begin(), static_cast<size_t>(x), 0);

    return A;
}
//NTT template ends
const int MAXN = 1e5 + 233;
// Non-zero for cycle vertices; get() also sets it temporarily to mark a processed centroid as removed.
int on_cycle[MAXN];
// n: number of vertices (and of input edges); k: exponent applied to each distance in main().
int n, k;
// Adjacency lists (self-loops are not stored; see main()).
vector<int> gra[MAXN];
// Subtree sizes filled in by get_centroid().
int siz[MAXN];
// Size of the tree component currently being decomposed, computed by getsize().
int sz = 0;
// Number of cycle vertices.
int anslen;
/*
 * Searches the tree component containing `now` (never entering vertices with
 * on_cycle set, nor going back to `fa`) for a centroid. A vertex qualifies when
 * the largest piece left after removing it is at most sz / 2 + 3, which is a
 * slightly relaxed centroid condition. `sz` must already hold the component
 * size (see getsize()).
 * Returns the first qualifying vertex in DFS post-order, or 0 if none was found
 * below `now`. siz[] is filled for the vertices visited along the way.
 */
int get_centroid(int now, int fa = -1) {
    siz[now] = 1;
    int heaviest = -INF;  // largest piece remaining if `now` is removed

    for (auto nxt : gra[now]) {
        if (nxt == fa || on_cycle[nxt])
            continue;

        const int found = get_centroid(nxt, now);

        if (found)
            return found;

        siz[now] += siz[nxt];
        heaviest = max(heaviest, siz[nxt]);
    }

    // The part of the component "above" `now`.
    heaviest = max(heaviest, sz - siz[now]);
    return heaviest <= sz / 2 + 3 ? now : 0;
}
// Cycle vertices, in the order findcycle() pops them off its DFS stack.
vector<int> cycle;
// Set by main() when the input contains a self-loop; the remaining graph is then processed as a tree.
bool ok = 0;
// DFS visited flags for findcycle().
bool vis[MAXN];
// findcycle()'s DFS stack holding the current root-to-vertex path.
stack<int> sta;
/*
 * DFS from `now` (reached from `pre`) that stops at the first edge leading to
 * an already-visited vertex and moves the vertices of the cycle it closes from
 * the DFS stack `sta` into `cycle`.
 * Only ONE edge back to `pre` is skipped (the tree edge we arrived by). Before,
 * every edge to `pre` was ignored, so a pair of parallel edges, which forms a
 * 2-cycle, was never reported and `cycle` stayed empty.
 */
void findcycle(int now, int pre = -1) {
    if (cycle.size())
        return;

    vis[now] = true;
    sta.push(now);
    bool skipped_parent_edge = false;

    for (auto it : gra[now]) {
        if (it == pre && !skipped_parent_edge) {
            skipped_parent_edge = true;
            continue;
        }

        if (vis[it]) {
            // `it` is on the current DFS path: pop the path back to it.
            int Now;

            do {
                Now = sta.top();
                sta.pop();
                cycle.push_back(Now);
            } while (Now != it);

            return;
        }

        findcycle(it, now);

        if (cycle.size())
            return;
    }

    sta.pop();
}
// f[i][j]: number of vertices at depth j in the tree hanging off the i-th cycle vertex (itself at depth 0).
// main() copies f[0 .. anslen) into f[anslen .. 2*anslen) to turn the cycle into a doubled chain.
vector<int> f[MAXN * 2];
// ret[d]: number of unordered vertex pairs at distance d, modulo MOD.
vector<int> ret;
/*
 * Depth histogram: for every vertex of the tree component of `now` (never
 * crossing into vertices with on_cycle set), increments v[its depth], where
 * `now` itself has depth `depth`. v is grown as needed.
 */
void calc(vector<int> &v, int now, int depth = 0, int pre = -1) {
    if (static_cast<int>(v.size()) < depth + 1)
        v.resize(depth + 1);

    ++v[depth];

    for (auto nxt : gra[now]) {
        if (nxt == pre || on_cycle[nxt])
            continue;

        calc(v, nxt, depth + 1, now);
    }
}
// Adds to the global `sz` the number of vertices in the tree component of `now`,
// not crossing into vertices with on_cycle set (`pre` is the vertex we came from).
void getsize(int now, int pre = -1) {
    ++sz;

    for (auto nxt : gra[now]) {
        if (nxt == pre || on_cycle[nxt])
            continue;

        getsize(nxt, now);
    }
}
/*
 * Centroid decomposition of the tree component containing `now` (vertices with
 * on_cycle set act as walls). For every centroid c it adds to ret[d] the number
 * of vertex pairs at distance d whose path goes through c.
 *
 * The child subtrees of c are merged in increasing order of histogram length.
 * The accumulated histogram is then never longer than the subtree being merged,
 * so each multiplication costs O(subtree size * log). Merging in adjacency
 * order could cost O(component size * log) per child, which is quadratic when a
 * deep subtree is followed by many shallow ones.
 */
void get(int now) {
    sz = 0;
    getsize(now);
    now = get_centroid(now);
    // Temporarily mark the centroid as removed so recursion and calc() stop at it.
    bool pre = on_cycle[now];
    on_cycle[now] = 1;

    vector<vector<int> > subtrees;

    for (auto it : gra[now])
        if (!on_cycle[it]) {
            get(it);
            vector<int> tmp;
            calc(tmp, it);
            // Depths are measured from `it`; shift by one to get distances from the centroid.
            subtrees.push_back(right_shift(tmp, 1));
        }

    sort(subtrees.begin(), subtrees.end(),
         [](const vector<int> &a, const vector<int> &b) {
             return a.size() < b.size();
         });

    // presum[d]: already-merged vertices (plus the centroid itself at d = 0) at distance d.
    vector<int> presum(1, 1);

    for (const auto &dist : subtrees) {
        add(ret, mul(presum, dist));
        add(presum, dist);
    }

    on_cycle[now] = pre;
}
// Largest cycle offset handled by div1()/divc(): floor((anslen - 1) / 2), set in main().
int to;
/*
 * Adds to ret the contribution of every pair of chain positions (i, j) with
 * i in [ansl, min(ansr, anslen)) and j in [l, r): a vertex at depth a below
 * position i and a vertex at depth b below position j are counted at distance
 * (j - i) + a + b.
 * All pairs are handled by a single convolution. The left histograms are
 * shifted by (ansr - 1 - i) and the right ones by (j - l), so product index p
 * corresponds to distance p + l - ansr + 1 (= p + gap).
 * Callers must pick ranges where j - i is the shorter way around the cycle.
 */
void solve(int ansl, int ansr, int l, int r) {
    // Left positions are restricted to the first copy of the cycle, so each
    // cycle pair is produced exactly once.
    check_min(ansr, anslen);

    if (ansl >= ansr || l >= r)
        return;

    vector<int> lp, rp;
    // lp = sum over i of f[i] * x^(ansr - 1 - i)
    rb(i, ansl, ansr - 1) {
        int st = ansr - 1 - i;

        if (st + f[i].size() > lp.size())
            lp.resize(st + f[i].size());

        rep(j, f[i].size()) {
            (lp[j + st] += f[i][j]) %= MOD;
        }
    }
    // rp = sum over j of f[j] * x^(j - l)
    rb(i, l, r - 1) {
        int st = i - l;

        if (st + f[i].size() > rp.size())
            rp.resize(st + f[i].size());

        rep(j, f[i].size()) {
            (rp[j + st] += f[i][j]) %= MOD;
        }
    }
    vector<int> tmp = mul(lp, rp);
    int gap = l - ansr + 1;
    // NOTE(review): ret is indexed without a bounds check; this relies on every
    // produced distance fitting in ret (sized n + 1 in main()).
    rep(i, tmp.size()) {
        (ret[i + gap] += tmp[i]) %= MOD;
    }
}
/*
 * Pairs between two adjacent chain segments. For every right position j in
 * [l, r) it adds the pairs (i, j) with i in [j - to, ansr), i.e. j - i <= to.
 * The valid left range is a suffix whose start j - to moves right with j, so
 * the j-range is halved:
 *   - the right half [mid, r) recurses with left range [mid - to, ansr);
 *   - the left half [l, mid) recurses on left range [ansl, mid - to);
 *   - [mid - to, ansr) is valid for every j < mid, so one solve() covers it.
 * `ansl` is only forwarded; at the leaves the lower bound is l - to directly.
 */
void divc(int ansl, int ansr, int l, int r) {
    if (l == r - 1) {
        solve(l - to, ansr, l, r);
        return ;
    }

    int mid = (l + r) >> 1;
    int ansmid = mid - to;
    divc(ansmid, ansr, mid, r);
    divc(ansl, ansmid, l, mid);
    solve(ansmid, ansr, l, mid);
}
/*
 * Pairs inside one chain segment [l, r): adds every (i, j) with l <= i < j < r.
 * main() cuts segments at most `to` long, so every such pair has j - i < to.
 * The split point balances the total histogram length (sum of f[i].size()) on
 * the two sides (weighted divide and conquer), with the earliest index winning
 * ties. Cross pairs are added by solve(); each side then recurses.
 * Segments lying in the duplicated half contribute nothing, because solve()
 * clips the left range to [0, anslen).
 */
void div1(int l, int r) {
    if (l == r - 1)
        return ;

    mp best = {INF, INF};
    int tot = 0;
    rb(i, l, r - 1) tot += f[i].size();
    tot /= 2;
    // After subtracting f[l..i], |tot| measures how far a split after i is from the weighted middle.
    rb(i, l, r - 1) {
        tot -= f[i].size();
        check_min(best, II(abs(tot), i));
    }
    int mid = best.second + 1;
    solve(l, mid, mid, r);
    div1(l, mid);
    div1(mid, r);
}
main() {
    //  freopen("sub35.in","r",stdin);
    //  scanf("%d%d",&n,&k);
    n = read();
    k = read();
    rb(i, 1, n) {
        int u, v;
        //      scanf("%d%d",&u,&v);
        u = read();
        v = read();

        if (u == v) {
            ok = true;
            continue;
        }

        gra[u].PB(v), gra[v].PB(u);
    }
    int rest = 0;
    ret.resize(n + 1);

    if (ok) {
        get(1);
    } else {
        findcycle(1);

        for (auto it : cycle)
            on_cycle[it] = true;

        int now = 0;

        for (auto it : cycle) {
            on_cycle[it] = false;
            calc(f[now++], it), get(it);
            on_cycle[it] = true;
        }

        anslen = cycle.size();
        cycle.resize(anslen + anslen);
        rep(i, anslen) cycle[i + anslen] = cycle[i], f[i + anslen] = f[i];
        to = anslen / 2;

        if (anslen & 1);
        else {
            to--;
            vector<int> tmp;

            rep(i, anslen) if (i < (i + to + 1) % anslen)
                add(tmp, mul(f[i], f[i + to + 1]));

            tmp = right_shift(tmp, to + 1);
            add(ret, tmp);
        }

        if (to) {
            vector<mp> each;
            int now = 0;

            while (now < anslen + anslen) {
                int nex = min(anslen + anslen, now + to);
                each.PB(II(now, nex));
                div1(now, nex);
                now = nex;
            }

            rb(i, 1, each.size() - 1) {
                divc(each[i - 1].FIR, each[i - 1].SEC, each[i].FIR, each[i].SEC);
            }
        }
    }

    rb(i, 0, n) {
        rest += 1ll * ret[i] * quick(i, k) % MOD;

        if (rest >= MOD)
            rest -= MOD;
    }
    rest = 1ll * rest * inv(1ll * n * (n - 1) / 2 % MOD) % MOD;
    cout << rest << endl;
    return 0;
}
posted @ 2021-01-22 12:20  WWW~~~  阅读(131)  评论(0编辑  收藏  举报