P3320 [SDOI2015]寻宝游戏
题意描述
给你一棵树,每个边都有一个边权。对于一个点有两个状态 \(0/1\) 表示这个点是否需要被经过。有 \(m\) 组询问,每次
会把一个点的状态取反。对于每组询问,回答出走完所有需要经过的点并回到起点的最短路径。
solution
考虑一下这个问题的简化版,给你一棵树,你需要经过一个点集中所有的点,求走的最短路程。
显然我们按照 \(dfs\) 序的顺序是最优的,因为这样保证了你不会走重复的道路(感性理解一下)。
那么最短路径就是相连两点之间的距离和加上起始点到终点的距离。
回到这个题,要求我们动态维护这个点集。支持加一个点删除一个点,动态维护答案。
平衡树? 太难写了,还是用 \(set\) 吧。
拿 \(set\) 维护一下这个按 \(dfs\) 序从小到大形成的一个序列。然后加入(删除)一个点维护一下对答案的贡献即
可。为了防止出现边界问题,可以在一开始加入一个 \(0\) 号点和一个 \(n+1\) 号点。
Code
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<set>
using namespace std;
#define int long long
const int N = 1e5+10;
int n,m,num,tot,u,v,x,ans,w;
int head[N],dep[N],fa[N],siz[N],son[N],top[N],dfn[N],sum[N],id[N];
bool vis[N];
struct node
{
int to,net,w;
}e[N<<1];
set<int>s;
inline int read()
{
int s = 0,w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
return s * w;
}
void add(int x,int y,int w)
{
e[++tot].to = y;
e[tot].w = w;
e[tot].net = head[x];
head[x] = tot;
}
void get_tree(int x)
{
dep[x] = dep[fa[x]] + 1; siz[x] = 1;
for(int i = head[x]; i; i = e[i].net)
{
int to = e[i].to;
if(to == fa[x]) continue;
fa[to] = x;
sum[to] = sum[x] + e[i].w;
get_tree(to);
siz[x] += siz[to];
if(siz[to] > siz[son[x]]) son[x] = to;
}
}
void dfs(int x,int topp)
{
top[x] = topp; dfn[x] = ++num; id[dfn[x]] = x;
if(son[x]) dfs(son[x],topp);
for(int i = head[x]; i; i = e[i].net)
{
int to = e[i].to;
if(to == fa[x] || to == son[x]) continue;
dfs(to,to);
}
}
int lca(int x,int y)
{
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]]) swap(x,y);
x = fa[top[x]];
}
return dep[x] <= dep[y] ? x : y;
}
int dis(int x,int y)
{
if(x == 0 || y == 0) return 0;
return sum[x] + sum[y] - 2 * sum[lca(x,y)];
}
signed main()
{
n = read(); m = read();
for(int i = 1; i <= n-1; i++)
{
u = read(); v = read(); w = read();
add(u,v,w); add(v,u,w);
}
get_tree(1); dfs(1,1);
s.insert(0); s.insert(n+1);
for(int i = 1; i <= m; i++)
{
x = read();
vis[x] ^= 1;
set<int>::iterator itl,itr;
if(vis[x] == 1)
{
itr = s.lower_bound(dfn[x]);
itl = itr; --itl;//动态维护贡献
ans -= dis(id[*itl],id[*itr]);
ans += dis(x,id[*itl]);
ans += dis(x,id[*itr]);
s.insert(dfn[x]);
}
else
{
itr = s.lower_bound(dfn[x]);
s.erase(itr);
itr = s.lower_bound(dfn[x]);
itl = itr; --itl;
ans -= dis(x,id[*itl]);
ans -= dis(x,id[*itr]);
ans += dis(id[*itl],id[*itr]);
}
itl = s.begin(); itr = s.end();
++itl; --itr; --itr;
printf("%lld\n",ans + dis(id[*itl],id[*itr]));
}
return 0;
}