「解题报告」[省选联考 2022] 填树

link

发现写正经省选题和写 ARC / CF 题感觉就是不一样。

感觉 AtCoder 和 CodeForces 出题风格和 CCF 出题风格区别确实还是很大。

还是补下题⑧

去年考场上这道题想到了可以把贡献写成分段函数的形式,但是剩下的我想直接暴力维护分段函数的多项式,然后就啥都不会了。

我去年不会插值这种东西吗?好像是根本没见过插值的套路?


最大值减最小值的限制可以转化成我们枚举一个值域 \([l, r]\) 满足 \(r - l = k\),并且 \(x_i \in [l, r]\)。但是注意到如果最大值减最小值小于 \(k\) 时,同一种方案会算多遍。考虑钦定最大值必须取到,枚举一个右端点 \(r\),那么我们可以用 \([r - k, r]\) 的答案减去 \([r - k, r - 1]\) 的答案,这样就能钦定最大值必须取到,就不会算重了。对所有的右端点求和,我们就相当于要求所有区间长为 \(k + 1\) 减去长度为 \(k\) 的答案。

先单独拿出长度为 \(k\) 的来看怎么做。我们仍然可以沿用上面的思想,枚举右端点,然后计算当前的答案是什么。这样我们可以很轻松的计算出每个点可以选的方案数去权值和,跑一遍树形 DP 即可。这样我们就得到了一个 \(O(nV)\) 的做法。

考虑优化枚举右端点的部分。我们可以将每个点的方案数 / 权值和写成分段函数的形式,容易发现每一段都是一个多项式函数,那么最后我们要求的答案就也是一个分段多项式函数。而这个多项式的次数是 \(O(n)\) 的。

那么我们可以考虑对于分段函数的每一段,只计算出 \(O(n)\) 个点值,求一下前缀和,然后就可以通过拉格朗日插值插出这一段的前缀和了。

段数是 \(O(n)\) 的,每一段要计算 \(O(n)\) 个点值,计算一次复杂度为 \(O(n)\),所以总复杂度就是 \(O(n^3)\)

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 805, P = 1000000007, B = 210;
int n, k;
long long l[MAXN], r[MAXN];
vector<int> e[MAXN];
int qpow(long long a, long long b) {
    int ans = 1;
    a %= P;
    while (b) {
        if (b & 1) ans = 1ll * ans * a % P;
        a = 1ll * a * a % P;
        b >>= 1;
    }
    return ans;
}
long long F(int i, long long x) {
    return max(0ll, min(r[i], x) - max(l[i], x - k + 1) + 1);
}
long long G(int i, long long x) {
    long long L = max(l[i], x - k + 1), R = min(r[i], x);
    if (L > R) return 0;
    return 1ll * (L + R) * (R - L + 1) / 2 % P;
}
vector<long long> s;
long long f[MAXN], g[MAXN];
long long ans1, ans2, x;
long long a1[MAXN], a2[MAXN];
void dfs(int u, int pre) {
    long long fv = F(u, x), gv = G(u, x);
    f[u] = fv, g[u] = gv;
    ans1 = (ans1 + f[u]) % P;
    ans2 = (ans2 + g[u]) % P;
    for (int v : e[u]) if (v != pre) {
        dfs(v, u);
        ans1 = (ans1 + 1ll * f[u] * f[v]) % P;
        ans2 = (ans2 + 1ll * f[u] * g[v]) % P;
        ans2 = (ans2 + 1ll * g[u] * f[v]) % P;
        f[u] = (f[u] + 1ll * f[v] * fv) % P;
        g[u] = (g[u] + 1ll * f[v] * gv) % P;
        g[u] = (g[u] + 1ll * g[v] * fv) % P;
    }
}
long long fac[MAXN], inv[MAXN], pre[MAXN], suf[MAXN];
long long lagrange(long long y[], long long k, long long n) {
    if (n <= k) return y[n];
    fac[0] = 1, pre[0] = 1, suf[k + 1] = 1;
    for (int i = 1; i <= k; i++) {
        fac[i] = 1ll * fac[i - 1] * i % P;
        pre[i] = 1ll * pre[i - 1] * (n - i) % P;
    }
    inv[k] = qpow(fac[k], P - 2);
    for (int i = k; i >= 1; i--) {
        suf[i] = 1ll * suf[i + 1] * (n - i) % P;
        inv[i - 1] = 1ll * inv[i] * i % P;
    }
    int ans = 0;
    for (int i = 1; i <= k; i++) {
        ans = (ans + 1ll * y[i] * inv[k - i] % P * inv[i - 1] % P * 
            (((k - i) & 1) ? P - 1 : 1) % P * pre[i - 1] % P * suf[i + 1]) % P;
    }
    return ans;
}
pair<int, int> solve() {
    pair<int, int> ret;
    for (int i = 0; i < s.size() - 1; i++) {
        long long L = s[i], R = s[i + 1] - 1;
        for (long long i = L; i <= min(R, L + B); i++) {
            memset(f, 0, sizeof f);
            memset(g, 0, sizeof g);
            x = i;
            ans1 = ans2 = 0;
            dfs(1, 0);
            a1[i - L + 1] = ans1;
            a2[i - L + 1] = ans2;
        }
        for (int i = 1; i <= min(R, L + B) - L + 1; i++) {
            a1[i] = (a1[i - 1] + a1[i]) % P;
            a2[i] = (a2[i - 1] + a2[i]) % P;
        }
        ret.first = (ret.first + lagrange(a1, min(R, L + B) - L + 1, R - L + 1)) % P;
        ret.second = (ret.second + lagrange(a2, min(R, L + B) - L + 1, R - L + 1)) % P;
    }
    return ret;
}
int main() {
    // freopen("tree3.in", "r", stdin);
    scanf("%d%d", &n, &k);
    for (int i = 1; i <= n; i++) {
        scanf("%lld%lld", &l[i], &r[i]);
        s.push_back(l[i]);
        s.push_back(r[i]);
        s.push_back(l[i] + k);
        s.push_back(r[i] + k);
    }
    s.push_back(0);
    s.push_back(5000000001);
    sort(s.begin(), s.end());
    s.erase(unique(s.begin(), s.end()), s.end());
    for (int i = 1; i < n; i++) {
        int u, v; scanf("%d%d", &u, &v);
        e[u].push_back(v);
        e[v].push_back(u);
    }
    k++;
    pair<int, int> ans1 = solve();
    k--;
    pair<int, int> ans2 = solve();
    printf("%d\n%d\n", (ans1.first - ans2.first + P) % P, (ans1.second - ans2.second + P) % P);
    return 0;
}
posted @ 2023-02-03 17:42  APJifengc  阅读(36)  评论(0编辑  收藏  举报