题解 [AGC007E] Shik and Travel
首先可以想到二分答案转化为判定
然后我猜测可以贪心每次选能到的点中最大的,但沈老师说假了我也不知道为什么
然后正解:
一个 \(O(炸天)\) 的 DP 是令 \(f_u(i, j)\) 为 \(u\) 子树内到最先到的点路径权值为 \(i\),到最后到的点路径权值为 \(j\) 且所有其它路径权值不超过 \(mid\) 是否可行
转移考虑枚举跨过 \(u\) 的路径权值
\[f_u(a, b)=f_{lson_u}(a, i)\&f_{rson_u}(j, b)[i+j+w(u, lson_u)+w(u, rson_u)\leqslant mid]
\]
然后这么个复杂度炸天应该没人回去想的东西是可以优化的
发现当 \(a_1\leqslant a_2\and b_1\leqslant b_2\) 时 \(f_u(a_2, b_2)\) 是没有用的
那么只需要对每个可行的 \(a\) 维护最小的 \(b\) 即可,可以排序后使用双指针转移
发现合并 \(lson, rson\) 产生的状态数为 \(2\min\{siz_{lson}, siz_{rson}\}\),那么状态数是 log 级别的
加上二分的 log 总复杂度 \(O(n\log^2 n)\)
- 一些复杂度很高的 DP 的状态可能带有一些「一个状态的每一维都严格大于另一个状态则那个状态没用」的性质
此时可以使用单调性只保留又可能成为最优解的状态
p.s. 代码里使用了 sort 所以是 \(O(n\log^3 n)\) 的,但是可以归并去掉一个 log
而且代码里也为归并预留了接口
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f3f3f3f3f
#define N 200010
#define fir first
#define sec second
#define pb push_back
#define ll long long
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n;
ll sum, mid;
vector<pair<ll, ll>> f[N];
vector<pair<int, int>> to[N];
void dfs(int u) {
f[u].clear();
if (!to[u].size()) {f[u].pb({0, 0}); return ;}
vector<pair<ll, ll>> sta[2], tem;
for (int i=0; i<2; ++i) dfs(to[u][i].fir);
for (int i=0; i<2; ++i) {
ll lim=mid-to[u][0].sec-to[u][1].sec;
int ls=to[u][0^i].fir, rs=to[u][1^i].fir;
for (int p1=0,p2=0; p1<f[ls].size(); ++p1) {
while (p2+1<f[rs].size() && f[ls][p1].sec+f[rs][p2+1].fir<=lim) ++p2;
if (p2<f[rs].size() && f[ls][p1].sec+f[rs][p2].fir<=lim) sta[i].pb({f[ls][p1].fir+to[u][0^i].sec, f[rs][p2].sec+to[u][1^i].sec});
}
}
for (int i=0; i<2; ++i) for (auto it:sta[i]) tem.pb(it);
sort(tem.begin(), tem.end());
ll lst=INF;
for (int i=0; i<tem.size(); ++i) if (tem[i].sec<lst)
f[u].pb(tem[i]), lst=tem[i].sec;
}
bool check(ll mid) {
// cout<<"check: "<<mid<<endl;
::mid=mid;
dfs(1);
// cout<<"---f---"<<endl; for (int i=1; i<=n; ++i) {cout<<i<<": "; for (auto it:f[i]) cout<<"("<<it.fir<<','<<it.sec<<") "; cout<<endl;}
return f[1].size();
}
signed main()
{
n=read();
for (int i=2,fa,val; i<=n; ++i) {
fa=read(); sum+=(val=read());
to[fa].pb({i, val});
}
ll l=0, r=sum, mid;
while (l<=r) {
mid=(l+r)>>1;
if (check(mid)) r=mid-1;
else l=mid+1;
}
printf("%lld\n", r+1);
return 0;
}