CF671D(set 维护整体 dp)
翻别人博客的时候翻到的一道题
- 给定一棵 n 个点的以 1 为根的树。
- 有 m 条路径 (x,y),保证 y 是 x 或 x 的祖先,每条路径有一个权值。
- 你要在这些路径中选择若干条路径,使它们能覆盖每条边,同时权值和最小。
- \(n,m \le 3 \times 10^5\)
首先可以想到一个显然的 dp。
设 \(dp[i][j]\) 表示以 i 为根的子树,向上延申了 j 个点。然后转移就是了复杂度 \(O(n^2)\)
但这个复杂度不太能过得去,我们换一种形式。
设 \(dp[i][j]\) 表示在 i 为根的子树,向上支配到 j 深度。并且设 \(f(i)=\min\limits_{i=1}^{dep[i]-1}dp[i][j]\)
然后有
\[dp[i][j] = \sum_{v\in son(i)}f(v)-\min_{v\in son(i)}(dp[v][j]-f(v))
\]
\[dp[i][anc] = \sum_{v\in son(i)}f(v)+c
\]
仔细分析一波,第一个式子相当于一个线段树合并维护 dp,而第二个式子则是在做一个全局加法。
可以想到使用线段树合并来维护 dp。
不过 \(\operatorname{256MB}\) 空间复杂度好像不太能过得去。
有一种奇妙的做法是用 set
来维护这个 dp。
具体来说,每个节点维护一个 set
,里面存放形如 \((j,dp[i][j])\) 的二元组,转移的时候相当于做区间加法、合并两个 set
,并且需要支持随时取出 \(f(i)\) 值
注意到我们可以随时维护二元组的第一位 j,不超过 \(dep[i]\) 并且对每一个 \(j\) 只有一个二元组,这样前两个操作非常容易处理。
而对于第二个操作,数据结构的优势已经基本用尽,要还想随时维护的话,就得继续嵌套其他数据结构。这时候注意到一个贪心性质,set
中的二元组是单调递减的,所以只需要修改的时候顺便维护 set
的单调性就行。
代码细节一般,注意合并 set
的时候启发式合并。
// 代码中没有启发式合并就草过去了/xk
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define pb emplace_back
#define pii pair<int,int>
template<typename _T>
inline void read(_T &x)
{
x= 0 ;int f =1;char s=getchar();
while(s<'0' ||s>'9') {f =1;if(s == '-')f =- 1;s=getchar();}
while('0'<=s&&s<='9'){x = (x<<3) + (x<<1) + s - '0';s= getchar();}
x*=f;
}
const int np = 3e5 + 5;
const int INF = 1 << 30;
int head[np],ver[np * 2],nxt[np * 2],tit;
int f[np],id[np];
vector<pii> vec[np];
int tmp[np],dep[np],son[np];
int n,m,pre[np],siz[np];
set<pii> s[np];
int tag[np];
inline void add(int x,int y)
{
ver[++tit] = y;
nxt[tit] = head[x];
head[x] = tit;
}
inline void solve(int x)
{
int a1(-1),a2(-1);
for(set<pii>::iterator it = s[id[x]].begin(),iter;it != s[id[x]].end();it ++)
{
iter = it;
int a1_ = (*it).first;
int a2_ = (*it).second;
if(a1==-1)
{
a1 = a1_;
a2 = a2_;
continue;
}
if(a2 <= a2_){
iter++;
s[id[x]].erase(it);
it = iter;
it--;
}else{
a1 = a1_;
a2 = a2_;
}
}
}
inline void ins(int u,int j,int val)
{
val += tag[id[u]];
if(j > dep[u]) return ;
set<pii>::iterator it = s[id[u]].lower_bound((pii){j,-INF});
set<pii>::iterator iter,aux;
if(it == s[id[u]].end() || (*it).first != j)
{
s[id[u]].insert((pii){j,val});
iter = s[id[u]].lower_bound((pii){j,val});
// it = iter;
// if(iter == s[id[u]].begin()) return ;
// iter--;
// if((*iter).second <= val){
// s[id[u]].erase(it);
// }
}
else{
if((*it).second < val) return ;
s[id[u]].erase(it);
s[id[u]].insert((pii){j,val});
}
}
inline void dfs(int x,int ff)
{
int F(0);
dep[x] = dep[ff] + 1;
for(int i=head[x],v;i;i=nxt[i])
{
v = ver[i];
if(v == ff) continue;
dfs(v,x);
F += f[v];
if(f[v] == -1){
f[x]=-1;
return ;
}
}
tag[id[x]] += F;
for(auto pi:vec[x]){
int anc = pi.first;
int sd = pi.second;
ins(x,dep[anc],sd);
}
ins(x,dep[x],0);
for(int i=head[x],v;i;i=nxt[i])
{
v = ver[i];
if(v == ff) continue;
// if(s[x].size() > s[v].size()) swap(id[x],id[v]);
for(set<pii>::iterator it = s[id[v]].begin();it!=s[id[v]].end();it ++)
{
int j = (*it).first;
int val = (*it).second;
ins(x,j,val-f[v]);
}
}
solve(x);
// cerr<< x <<" : ";
// printf("%d : ",x);
// for(auto i:s[id[x]])
// {
// cerr<<"("<<i.first<<","<<i.second<<")";
// }
// cerr<<'\n';
// puts("");
if(x == 1) return;
// cout<<((*(s[x].rbegin())).second)<<'\n';
if((*(s[id[x]].rbegin())).first != dep[x])f[x] = (*(s[id[x]].rbegin())).second;
else {
set<pii >::iterator it = s[id[x]].end();
--it;
if(it == s[id[x]].begin()) {
f[x]=-1;
return ;
}//f[x] =-1;
it--;
f[x] = (*it).second;
}//f[x] = (*(--s[x].rbegin())).second;
// printf("%d\n",f[x]);
}
signed main()
{
read(n),read(m);
for(int i=1,x,y;i <= n- 1;i ++)
{
read(x),read(y);
add(x,y),add(y,x);
}
for(int i=1,x,y,val;i <= m;i ++){
read(x),read(y),read(val);
vec[x].pb((pii){y,val});
}
for(int i=1;i <= n;i ++) id[i] = i;
dfs(1,0);
if(f[1] == -1)
{
puts("-1");
return 0;
}
printf("%lld\n",(*(s[1].lower_bound((pii){1,-INF}))).second);
}