20221114_T4B_树形dp换根dp
题意
太冗长了
传一张图片自己看吧。
题解
赛时得分 15/100/100
赛时写了 \(A=0\) 的乱搞,没写对但是拿了 15pts。
首先这个函数是一个增函数,对于 power 和 score 两个指标来说。
既然是增函数,那么我们显然是当前 score 越大越好,并且 power 具有单调性,可以二分。
我们首先二分 power,然后选择树形 dp。
首先一轮记录从叶子节点走上来的 score,power在整个过程中是不会改变的。
然后我们考虑记录从父亲节点下来的到子节点的最优秀的一条,走到一个叶节点计算答案就好了。注意要维护最大值和次大值,因为这个叶节点不能是我从父亲走下来的叶节点。
代码
#include <bits/stdc++.h>
using namespace std;
template <typename T>inline void read(T& t){t=0; register char ch=getchar(); register int fflag=1;while(!('0'<=ch&&ch<='9')) {if(ch=='-') fflag=-1;ch=getchar();}while(('0'<=ch&&ch<='9')){t=t*10+ch-'0'; ch=getchar();} t*=fflag;}
template <typename T,typename... Args> inline void read(T& t, Args&... args) {read(t);read(args...);}
const int N = 1e5 + 10, inf = 0x3f3f3f3f;
typedef long double db;
int n, A, du[N], fa[N], fr1[N], fr2[N];
db ori, tar, f1[N], f2[N], g[N];
const db eps = 1e-8;
struct person {
db power, score;
void R() {cin >> power >> score;}
db sig(db x) {
if(fabs(x) < eps) return 0;
if(x > 0) return 1;
else return -1;
}
db YeMenYaoZhanDou(person x) {
db dpow = power - x.power, dpoi = score - x.score;
db delta = 2 * sig(dpow) * (sqrt(fabs(dpow) + 1) - 1) - A * sig(dpoi) * (sqrt(fabs(dpoi) + 1) - 1);
return x.score - delta;
}
}a[N];
vector<int>G[N];
db ans;
void dfs(int u, db power) {
f1[u] = f2[u] = -1e9; fr1[u] = fr2[u] = 0;
if(du[u] <= 1 && u != 1) {f1[u] = a[u].YeMenYaoZhanDou((person){power, ori}); return;}
for(int v : G[u]) {
if(v == fa[u]) continue; fa[v] = u;
dfs(v, power);
db zhan = a[u].YeMenYaoZhanDou((person){power, f1[v]});
if(f2[u] < zhan) f2[u] = zhan, fr2[u] = v;
if(f1[u] < f2[u]) swap(f1[u], f2[u]), swap(fr1[u], fr2[u]);
}
}
void dfs1(int u, db power) {
if(du[u] <= 1) {ans = max(ans, g[u]);} // solved
for(int v : G[u]) {
if(v == fa[u]) continue;
db up = (v == fr1[u]) ? f2[u] : f1[u]; up = max(up, g[u]);
g[v] = a[v].YeMenYaoZhanDou((person){power, up});
dfs1(v, power);
}
}
bool check(db x) {
ans = -1e18; fa[1] = 0; if(du[1] <= 1) {g[1] = a[1].YeMenYaoZhanDou((person){x, ori});} else g[1] = -1e18;
dfs(1, x); dfs1(1, x);
return ans >= tar;
}
int main() {
freopen("pigeatyy.in", "r", stdin);
freopen("pigeatyy.out", "w", stdout);
read(n, n); cin >> ori >> tar >> A;
for(int i = 1; i < n; ++i) {
int u, v;
read(u, v);
G[u].push_back(v);
G[v].push_back(u);
du[u]++; du[v]++;
}
for(int i = 1; i <= n; ++i) a[i].R();
db l = -1e6, r = 1e6;
while(r - l > eps) {
db mid = (l + r) / 2;
if(check(mid)) r = mid;
else l = mid;
}
cout << fixed << setprecision(6) << l << endl;
return 0;
}