异象石
题意描述
首先声明:神仙题目。
求一棵树点集的最短路径覆盖,支持修改和查询。
树的点集最短路径覆盖指:使所有点联通的最短路径长度。
算法分析
首先可以康康这道题,这相当于是 \(3\) 个点的特殊情况。(而且还没有这么毒瘤的修改)
一步步来,首先考虑两个节点,显然答案为经过其 lca 的简单路径。
三个节点就稍微复杂:(建议自行画图理解)
假设三个节点为 \(x,y,z\),那么我们可以两两求出三点的 lca,显然有两个是重复的,且未重复的深度较深。
假设 \(lca(x,y)\) 为较深的节点,那么答案为 \(dis(x,y)+dis(lca(x,y),z)\)。
究其本质,发现选取较浅节点就会导致重复,而产生较深节点的原因是两点处于同一子树。
所以尽量选取同一子树的节点(且子树深度越深越好)可以得到最优的答案。
那么根据 dfs 的定义可以轻松地求出属于同一子树的两点(按 dfs 序排序),然后进行 \(n\) 个点的求解。
具体一点:
假设有 \(k\) 个节点,分别为 \(p_1,p_2,...,p_k\),则将其按照 dfs 序从小到大排序后,
\(ans=(dis(p_1,p_2)+dis(p_2,p_3)+...+dis(p_{k-1},p_k)+dis(p_k,p_1))/2\)。
同时因为本题还要支持修改,所以需要实时维护 dfs 序排序后的序列。
假设插入节点 \(a\),其 dfs 序为 \(dfn[a]\)。
插入排序后得到 \(a\)节点的前驱后继分别为 \(pre,nxt\)。
那么修改后 \(ans+=dis(a,pre)+dis(a,nxt)-dis(pre,nxt)\)。
删除只需要执行反操作即可,但要记得删除相应的节点。
所以我们需要一个支持插入、排序、删除、查询前驱后继的数据结构。(平衡树)
显然 STL set 即可胜任,于是本题完美解决。
代码实现
首先我们要了解一些 set 的食用姿势:
- 声明:类似
set<int>s
,支持结构体,但需要自行重载小于号。 - 插入:
s.insert(x)
。 - 查找:
s.find(x)
,如果没有返回s.end()
,否则返回该元素的迭代器,可用*
取消引用。 - 删除:
s.erase(x)
,如果x
为迭代器就删除其指向的元素,否则删除所有该元素。 - 传参:
void work(set<int>::iterator x)
,这里传的是迭代器,注意一下就好。 - 前驱后继:简单的
++
和--
操作。 - 首尾元素:
s.begin()
和--s.end()
,注意是左闭右开区间。
然后就可以快乐 AC 了(确信)。
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<set>
#define N 100010
using namespace std;
int n,m,tot=0;
int cnt=0,head[N];
int f[N][30];
int dfn[N],dep[N];
long long dis[N];
long long ans=0;
struct Edge{
int nxt,to,val;
}ed[N<<1];
struct node{
int x,y;
};
bool operator < (const node a,const node b){
return (a.y==b.y?a.x<b.x:a.y<b.y);
}
set<node>s;
int read(){
int x=0,f=1;char c=getchar();
while(c<'0' || c>'9') f=(c=='-')?-1:1,c=getchar();
while(c>='0' && c<='9') x=x*10+c-48,c=getchar();
return x*f;
}
void add(int u,int v,int w){
ed[++cnt]=(Edge){head[u],v,w};
head[u]=cnt;
return;
}
void dfs(int u,int fa){
dfn[u]=++tot;
dep[u]=dep[fa]+1;
f[u][0]=fa;
for(int i=head[u];i;i=ed[i].nxt){
int v=ed[i].to,w=ed[i].val;
if(v==fa) continue;
dis[v]=(long long)dis[u]+w;
dfs(v,u);
}
return;
}
void init(){
dfs(1,0);
for(int i=1;i<=20;i++)
for(int j=1;j<=n;j++)
f[j][i]=f[f[j][i-1]][i-1];
return;
}
int lca(int x,int y){
if(dep[x]>dep[y]) swap(x,y);
for(int i=20;i>=0;i--)
if(dep[f[y][i]]>=dep[x]) y=f[y][i];
if(x==y) return x;
for(int i=20;i>=0;i--)
if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
node get_pre(set<node>::iterator x){
if(x==s.begin()) return(*(--s.end()));
return *(--x);
}
node get_nxt(set<node>::iterator x){
if(x==--s.end()) return(*(s.begin()));
return *(++x);
}
long long get_ans(int x,int y){
return (long long)(dis[x]+dis[y]-2*dis[lca(x,y)]);
}
int main(){
n=read();
for(int i=1;i<n;i++){
int u=read(),v=read(),w=read();
add(u,v,w),add(v,u,w);
}
init();
m=read();
char typ;
for(int i=1;i<=m;i++){
typ=getchar();
while(typ!='+' && typ!='-' && typ!='?') typ=getchar();
if(typ=='?'){printf("%lld\n",ans/2);continue;}
int a=read();
node now;
now.x=a,now.y=dfn[a];
if(typ=='+'){
s.insert(now);
int pre=(get_pre(s.find(now))).x;
int nxt=(get_nxt(s.find(now))).x;
ans+=get_ans(a,pre)+get_ans(a,nxt)-get_ans(pre,nxt);
}else{
int pre=(get_pre(s.find(now))).x;
int nxt=(get_nxt(s.find(now))).x;
ans-=get_ans(a,pre)+get_ans(a,nxt)-get_ans(pre,nxt);
s.erase(s.find(now));
}
}
return 0;
}
完结撒花。