BZOJ3991 [SDOI2015]寻宝游戏 【dfs序 + lca + STL】
题目
小B最近正在玩一个寻宝游戏,这个游戏的地图中有N个村庄和N-1条道路,并且任何两个村庄之间有且仅有一条路径可达。游戏开始时,玩家可以任意选择一个村庄,瞬间转移到这个村庄,然后可以任意在地图的道路上行走,若走到某个村庄中有宝物,则视为找到该村庄内的宝物,直到找到所有宝物并返回到最初转移到的村庄为止。小B希望评测一下这个游戏的难度,因此他需要知道玩家找到所有宝物需要行走的最短路程。但是这个游戏中宝物经常变化,有时某个村庄中会突然出现宝物,有时某个村庄内的宝物会突然消失,因此小B需要不断地更新数据,但是小B太懒了,不愿意自己计算,因此他向你求助。为了简化问题,我们认为最开始时所有村庄内均没有宝物
输入格式
第一行,两个整数N、M,其中M为宝物的变动次数。
接下来的N-1行,每行三个整数x、y、z,表示村庄x、y之间有一条长度为z的道路。
接下来的M行,每行一个整数t,表示一个宝物变动的操作。若该操作前村庄t内没有宝物,则操作后村庄内有宝物;若该操作前村庄t内有宝物,则操作后村庄内没有宝物。
输出格式
M行,每行一个整数,其中第i行的整数表示第i次操作之后玩家找到所有宝物需要行走的最短路程。若只有一个村庄内有宝物,或者所有村庄内都没有宝物,则输出0。
输入样例
4 5
1 2 30
2 3 50
2 4 60
2
3
4
2
1
输出样例
0
100
220
220
280
提示
1<=N<=100000
1<=M<=100000
对于全部的数据,1<=z<=10^9
题解
其实,,不用建虚树,只是用虚树的思想
稍经模拟可以发现,最后的答案 = 所有点建出的虚树上的边长和 * 2
更简单的可以发现:\(ans=\)虚树中相邻dfs序的点距离之和【首尾也算相邻】
所以我们只需要维护一个dfs序集合
每次插入一个元素,就找到其前驱后继,减去前驱后继的距离,分别加上该点到前驱到后继的距离
每次删除一个元素,就找到其前去后缀,分别减去该点到前驱到后继的距离,加上前驱后继的距离
输出答案时在加上首尾之间的距离
比较懒就用set维护了,【貌似是第一次用set??我真是好学生】
#include<iostream>
#include<cstdio>
#include<cmath>
#include<set>
#include<map>
#include<cstring>
#include<algorithm>
#define LL long long int
#define Redge(u) for (int k = H[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define BUG(s,n) for (int i = 1; i <= (n); i++) cout<<s[i]<<' '; puts("");
#define mp(a,b) make_pair<int,int>(a,b)
#define cp make_pair<int,int>
using namespace std;
const int maxn = 100005,maxm = 100005,INF = 1000000000;
inline int read(){
int out = 0,flag = 1; char c = getchar();
while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
return out * flag;
}
int H[maxn],ne = 2;
struct EDGE{int to,nxt; LL w;}ed[maxn << 1];
inline void build(int u,int v,LL w){
ed[ne] = (EDGE){v,H[u],w}; H[u] = ne++;
ed[ne] = (EDGE){u,H[v],w}; H[v] = ne++;
}
int dfn[maxn],dep[maxn],fa[maxn][18],vis[maxn],h[maxn],n,m,cnt;
LL ans,D[maxn];
set<int> S;
void dfs(int u){
dfn[u] = ++cnt; h[cnt] = u;
REP(i,17) fa[u][i] = fa[fa[u][i - 1]][i - 1];
Redge(u) if ((to = ed[k].to) != fa[u][0]){
fa[to][0] = u; dep[to] = dep[u] + 1;
D[to] = D[u] + ed[k].w;
dfs(to);
}
}
int lca(int u,int v){
if (dep[u] < dep[v]) swap(u,v);
for (int i = 0,d = dep[u] - dep[v]; (1 << i) <= d; i++)
if (d & (1 << i)) u = fa[u][i];
if (u == v) return u;
for (int i = 17; i >= 0; i--)
if (fa[u][i] != fa[v][i]){
u = fa[u][i];
v = fa[v][i];
}
return fa[u][0];
}
LL dis(int u,int v){
int o = lca(u,v);
return D[u] + D[v] - 2 * D[o];
}
void cut(int u){
int pre = *--S.find(dfn[u]),post = *++S.find(dfn[u]);
if (pre >= 1) ans -= dis(h[pre],u);
if (post <= n) ans -= dis(h[post],u);
if (pre >= 1 && post <= n) ans += dis(h[pre],h[post]);
S.erase(dfn[u]);
}
void ins(int u){
S.insert(dfn[u]);
int pre = *--S.find(dfn[u]),post = *++S.find(dfn[u]);
if (pre >= 1) ans += dis(h[pre],u);
if (post <= n) ans += dis(h[post],u);
if (pre >= 1 && post <= n) ans -= dis(h[pre],h[post]);
}
int main(){
n = read(); m = read();
int u,v; LL w;
for (int i = 1; i < n; i++){
u = read(); v = read(); w = read();
build(u,v,w);
}
dfs(1);
S.insert(0); S.insert(n + 1);
while (m--){
u = read();
if (vis[u]) cut(u);
else ins(u);
LL add = 0;
int first = *++S.find(0),last = *--S.find(n + 1);
if (first >= 1 && last <= n) add = dis(h[first],h[last]);
printf("%lld\n",ans + add);
vis[u] ^= 1;
}
return 0;
}