P9755 [CSP-S 2023] 种树 题解
upd on 2023.11.20:增加细节说明。
刚开始浪费太多时间了,导致赛时没调出来,有点火大。
如果一开始没有头绪的话可以先看一下特殊性质,链是直接贪心。
考虑一下菊花,发现直接贪心是不可做的,但是发现树的高度随时间增大而增大,可以用二分转化为判定性问题解决。
对于任意的树来说,参考菊花,考虑二分出时间,转化为判定性问题。可以通过求根公式或者再套一层二分求出每个点最晚被种树的时间。然后贪心的按照求出的每个点最晚植树的时间 \(t\) 排序,从小到大的植树即可,由于每个点最多只会被染色一次,可以直接暴力染色,遇到访问过的终止即可,均摊下来是 \(\mathcal{O}(1)\) 的。
这样的话复杂度是 \(\mathcal{O}(q\log T\log n)\) 的,\(T\) 为最少天数,这里是 \(10^9\)。但是其实是可以去掉后面的 \(\log n\) 的,\(t_i\) 是可以 \(\mathcal{O}(1)\) 的用求根公式求出,最后的排序可以用桶,因为 \(t_i>n\) 的是无所谓的。这样的话就是单 \(\log\) 的了。空间线性。
注意有些地方可能会爆 long long。
代码里的二分写法可能和其他的不是很一样。
\(c_i < 0\) 时要注意 \(tim\) 和 \(mxc\) 的关系,因为生长的总高度是 \((b + mid\times c)+(b+(mid+1)\times c+\dots+(b + \min(mxc, tim)\times c))+\max(0, tim-mxc)\),即增长速度大于 \(1\) 的段不一定能将 \(mxc\) 天都取满,其中 \(mxc\) 是多少时间过后,增长速度恒为 \(1\),\(mid\) 和 \(tim\) 分别表示从第几天开始长和第几天全部长好。
代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define i128 __int128
#define vi vector<int>
#define eb emplace_back
#define pii pair<int, ll>
#define fi first
#define se second
#define TIME 1e3 * clock() / CLOCKS_PER_SEC
bool Mbe;
mt19937_64 rng(35);
constexpr int N = 1e5 + 10;
int n;
int b[N], c[N], fa[N];
ll a[N];
int head[N], cnt_e;
struct edge{
int v, nxt;
} e[N << 1];
void adde(int u,int v){
++cnt_e, e[cnt_e].v = v, e[cnt_e].nxt = head[u], head[u] = cnt_e;
}
void dfs(int u, int ff){
fa[u] = ff;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].v;
if(v == ff) continue;
dfs(v, u);
}
}
int stk[N], tp, vis[N], id[N];
ll t[N];
i128 calc(int i, int tl, int tr, int mxc) {
if(mxc < tl) return tr - tl + 1;
if(mxc > tr) return (tr - tl + 1) * 1ll * b[i] + (tr - tl + 1) * ((i128)(tl + tr)) / 2 * c[i];
return (mxc - tl + 1) * 1ll * b[i] + (mxc - tl + 1) * ((i128)tl + mxc) / 2 * c[i] + tr - mxc;
}
bool check(int tim){
for(int i = 1; i <= n; ++i) {
if(c[i] < 0) {
int mxc = ((b[i] - 1) / abs(c[i]));
int l = 1, r = n;
while(l <= r) {
int mid = (l + r) >> 1;
if(calc(i, mid, tim, mxc) >= ((i128)a[i])) l = mid + 1;
else r = mid - 1;
}
t[i] = r;
} else if(!c[i]) {
t[i] = tim - (a[i] + b[i] - 1) / b[i] + 1;
} else {
int l = 1, r = n;
while(l <= r) {
int mid = (l + r) >> 1;
if(b[i] * 1ll * (tim - mid + 1) + (tim + mid) * ((i128)tim - mid + 1) * c[i] / 2 >= ((i128)a[i]))
l = mid + 1;
else r = mid - 1;
}
t[i] = r;
}
if(t[i] <= 0) return 0;
}
for(int i = 1; i <= n; ++i) id[i] = i, vis[i] = 0;
sort(id + 1, id + n + 1, [](int a, int b) {
return t[a] < t[b];
});
int now = 0; tp = 0;
for(int i = 1; i <= n; ++i) {
int u = id[i];
for(; !vis[u]; u = fa[u]) vis[stk[++tp] = u] = 1;
while(tp)
if(t[stk[tp--]] < ++now)
return 0;
}
return 1;
}
bool Med;
int main(){
fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
ios :: sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n;
for(int i = 1; i <= n; ++i) cin >> a[i] >> b[i] >> c[i];
for(int i = 1, u, v; i < n; ++i) cin >> u >> v, adde(u, v), adde(v, u);
dfs(1,0);
vis[0] = 1;
int l = n, r = 1e9;
while(l <= r) {
int mid = (l + r) >> 1;
if(check(mid)) r = mid - 1;
else l = mid + 1;
}
cout << l << "\n";
cerr << TIME << "ms\n";
return 0;
}