省选联考 2022 填树

洛谷传送门

LOJ 传送门

这题做得真艰难。

先考虑第一问。

一眼看上去并没有什么复杂度脱离值域的办法。考虑枚举一个 x 表示最小值,那么点权只能在 [x,x+K] 中。

点权最小值不一定为 x,减去点权在 [x+1,x+K] 中的答案即可,也就是把 K1 后再算一遍。

那么可以得出每个点权的取值范围为 [max(x,li),min(x+K,ri)]

设第 i 个点有 au 种取值。答案就是树上所有简单路径的 au 乘积之和。

那么很容易做一个 dp,可以算出 fu 表示 u 子树内延伸到 u 的路径中,每条路径的取值之和。

合并儿子时有 fuau×fv

然后考虑所有 LCAu 的点对的贡献,相当于选 v1sonu,v2sonu,v1v2,能产生 fv1×fv2×au 的贡献。

我们发现,[max(x,li),min(x+K,ri)] 只能组合出来 4 种取值范围:[x,x+K],[x,ri],[li,x+K],[li,ri],并且只和 xliK,li,riK,ri 的大小关系有关。

所以我们有 O(n) 个断点,在每相邻两个断点组成的左开右闭区间 [L,R) 内,au 可以表示成关于 x 的至多一次项式 Ax+B

我们现在希望计算树上所有简单路径的多项式 au 的乘积之和,可以使用上述的 dp 做法求出,多项式乘法暴力就行。

设我们最后求出来的树上所有简单路径的多项式 au 的乘积之和为 i=0nAixi。答案就是 i=0nAix=LR1xi。也就是说要快速算 x=0NxM

这就是 CF622F The Sum of the k-th Powers。这个和就是一个 M+1 次多项式,直接拉格朗日插值即可。注意讨论 L,R 中有负数的情况。

然后考虑第二问。

仍然先考虑暴力。设第 u 个点所有取值之和为 bu,那么对于树上一条简单路径 p1,p2,,pk,我们希望求 i=1kbpijiapj

这个也可以 dp 求出。设 fu,0/1 表示一条从 u 子树内延伸到 u 的路径,中间是否有一个点乘的是 bi 而不是 ai

合并儿子时有转移 fu,1aufv,1+bufv,0

然后仍然考虑所有 LCAu 的点对的贡献,相当于选 v1sonu,v2sonu,v1v2,能产生 fv1,0×fv2,0×bu+fv1,1×fv2,0×au+fv1,0×fv2,1×au 的贡献。

然后也可以像第一问一样,先分段,然后把 bu 表示成关于 x 的至多二次项式,dp 后拉格朗日插值算 x=0NxM 解决。

时间复杂度 O(n3),但是好像跑得比大多数做法都快?

code
// Problem: P8290 [省选联考 2022] 填树
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P8290
// Memory Limit: 512 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 2020;
const ll mod = 1000000007;
const ll inv2 = (mod + 1) / 2;
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
ll n, m, a[maxn], b[maxn], lsh[maxn], tot, pw[maxn][maxn], fac[maxn], ifac[maxn];
vector<int> G[maxn];
typedef vector<ll> poly;
inline poly operator + (const poly &a, const poly &b) {
int n = (int)a.size(), m = (int)b.size();
poly res(max(n, m));
for (int i = 0; i < max(n, m); ++i) {
if (i < n) {
res[i] += a[i];
}
if (i < m) {
res[i] += b[i];
}
(res[i] >= mod) && (res[i] -= mod);
}
return res;
}
inline poly operator * (const poly &a, const poly &b) {
if (a.empty() || b.empty()) {
return poly();
}
int n = (int)a.size() - 1, m = (int)b.size() - 1;
poly res(n + m + 1);
for (int i = 0; i <= n; ++i) {
for (int j = 0; j <= m; ++j) {
res[i + j] = (res[i + j] + a[i] * b[j]) % mod;
}
}
return res;
}
ll pre[maxn], suf[maxn];
// 0 ^ m + 1 ^ m + 2 ^ m + ... + n ^ m
inline ll calc(ll n, ll m) {
if (n <= 0) {
return 0;
}
if (n <= m + 5) {
ll ans = 0;
for (int i = 0; i <= n; ++i) {
ans = (ans + pw[i][m]) % mod;
}
return ans;
}
pre[0] = 1;
for (int i = 1; i <= m + 2; ++i) {
pre[i] = pre[i - 1] * (n - i) % mod;
}
suf[m + 3] = 1;
for (int i = m + 2; i; --i) {
suf[i] = suf[i + 1] * (n - i) % mod;
}
ll s = 0, ans = 0;
for (int i = 1; i <= m + 2; ++i) {
s = (s + pw[i][m]) % mod;
ll coef = pre[i - 1] * suf[i + 1] % mod;
coef = coef * ifac[i - 1] % mod * ifac[m + 2 - i] % mod;
if ((m + 2 - i) & 1) {
coef = (mod - coef) % mod;
}
ans = (ans + coef * s) % mod;
}
return ans;
}
// l ^ m + (l + 1) ^ m + (l + 2) ^ m + ... + r ^ m
inline ll calc(ll l, ll r, ll m) {
if (!m) {
return (r - l + 1) % mod;
}
if (l <= 0 && r <= 0) {
return (mod + ((m & 1) ? (-calc(-l, m) + calc(-r - 1, m)) : (calc(-l, m) - calc(-r - 1, m)))) % mod;
} else if (l <= 0 && r > 0) {
return (mod + mod + ((m & 1) ? -calc(-l, m) : calc(-l, m)) + calc(r, m)) % mod;
} else {
return (calc(r, m) - calc(l - 1, m) + mod) % mod;
}
}
poly A[maxn], B[maxn], P, Q, F[maxn][2];
void dfs(int u, int fa) {
F[u][0] = A[u];
F[u][1] = B[u];
for (int v : G[u]) {
if (v == fa) {
continue;
}
dfs(v, u);
F[u][0] = F[u][0] + A[u] * F[v][0];
F[u][1] = F[u][1] + A[u] * F[v][1] + B[u] * F[v][0];
}
P = P + F[u][0];
Q = Q + F[u][1];
}
void dfs2(int u, int fa) {
poly a, b;
for (int v : G[u]) {
if (v == fa) {
continue;
}
dfs2(v, u);
P = P + F[v][0] * A[u] * a;
Q = Q + F[v][0] * A[u] * b + F[v][1] * A[u] * a + F[v][0] * B[u] * a;
a = a + F[v][0];
b = b + F[v][1];
}
}
inline pii calc(ll m) {
tot = 0;
for (int i = 1; i <= n; ++i) {
lsh[++tot] = a[i];
lsh[++tot] = a[i] - m;
lsh[++tot] = b[i] - m;
lsh[++tot] = b[i] + 1;
}
sort(lsh + 1, lsh + tot + 1);
tot = unique(lsh + 1, lsh + tot + 1) - lsh - 1;
ll ans1 = 0, ans2 = 0;
for (int _ = 1; _ < tot; ++_) {
ll L = lsh[_], R = lsh[_ + 1];
for (int i = 1; i <= n; ++i) {
A[i] = B[i] = poly();
ll l = a[i], r = b[i];
if (max(l, L) > min(r, L + m)) {
continue;
}
if (L >= l && L >= r - m) {
A[i] = poly(2);
A[i][0] = r + 1;
A[i][1] = mod - 1;
B[i] = poly(3);
B[i][0] = (r * r + r) % mod * inv2 % mod;
B[i][1] = inv2;
B[i][2] = (mod - inv2) % mod;
} else if (L >= l && L < r - m) {
A[i] = poly(1);
A[i][0] = m + 1;
B[i] = poly(2);
B[i][0] = m * (m + 1) % mod * inv2 % mod;
B[i][1] = m + 1;
} else if (L < l && L >= r - m) {
A[i] = poly(1);
A[i][0] = r - l + 1;
B[i] = poly(1);
B[i][0] = calc(l, r, 1);
} else if (L < l && L < r - m) {
A[i] = poly(2);
A[i][0] = (m - l + 1 + mod) % mod;
A[i][1] = 1;
B[i] = poly(3);
ll p = (l + m) % mod, q = (m - l + 1 + mod) % mod;
B[i][0] = p * q % mod * inv2 % mod;
B[i][1] = inv2 * (p + q) % mod;
B[i][2] = inv2;
}
}
P = Q = poly();
dfs(1, -1);
dfs2(1, -1);
for (int i = 0; i < (int)P.size(); ++i) {
ans1 = (ans1 + P[i] * calc(L, R - 1, i)) % mod;
}
for (int i = 0; i < (int)Q.size(); ++i) {
ans2 = (ans2 + Q[i] * calc(L, R - 1, i)) % mod;
}
// printf("%lld %lld %lld\n", L, R, ans2);
// for (int i = 1; i <= n; ++i) {
// printf("i: %d\n", i);
// for (ll x : B[i]) {
// printf("%lld ", x);
// }
// putchar('\n');
// }
}
return mkp(ans1, ans2);
}
void solve() {
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; ++i) {
scanf("%lld%lld", &a[i], &b[i]);
}
int up = n * 2 + 5;
for (int i = 0; i <= up; ++i) {
pw[i][0] = 1;
for (int j = 1; j <= up; ++j) {
pw[i][j] = pw[i][j - 1] * i % mod;
}
}
fac[0] = 1;
for (int i = 1; i <= up; ++i) {
fac[i] = fac[i - 1] * i % mod;
}
ifac[up] = qpow(fac[up], mod - 2);
for (int i = up - 1; ~i; --i) {
ifac[i] = ifac[i + 1] * (i + 1) % mod;
}
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
G[u].pb(v);
G[v].pb(u);
}
pii x = calc(m), y = calc(m - 1);
printf("%lld\n%lld\n", (x.fst - y.fst + mod) % mod, (x.scd - y.scd + mod) % mod);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}
posted @   zltzlt  阅读(22)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示