「国家集训队」旅游 题解 (树链剖分)
题目简介
对一颗节点为 \(N\) 树 (节点编号 \(0~N-1\)),进行 \(M\) 次操作:
C i w
将输入的第 \(i\) 条边权值改为 \(w\)N u v
将 \(u,v\) 节点之间的边权都变为相反数SUM u v
询问 \(u,v\)节点之间边权和MAX u v
询问 \(u,v\) 节点之间边权最大值MIN u v
询问 \(u,v\) 节点之间边权最小值
分析
其实就是树剖模板,不过需要支持的操作有点多。
SUM
,MAX
,MIN
操作都很简单,直接线段树维护三个信息就行了:
struct Segment_Tree{
int l,r;
int sum;
int maxv;
int minv;
bool tag;
}seg[Maxn<<2];
对于C
操作,我们可以略施小计,使用邻接表储存时从编号1开始储存,由于邻接表内两项的起点终点是刚好相反的,所以我们可以借此通过异或算法求出一条边在邻接表中出现的位置。
具体的,我们可以用一个pos
数组在第一遍dfs时记录到达 \(y\) 是通过邻接表内第几号边。在第二遍dfs更新id
和·rev
数组时,可以同时更新一个rpt
数组,记录邻接表编号所对应的线段树节点编号。
然后对于该操作:
modifit1(1,rpt[x<<1],y);
太玄学?不理解?等会儿看代码。
对于N
操作,可要用到tag
数组了。
我们这样定义一个tag
:某个节点及其子树需要变成相反数,则tag
=\(\mbox{true}\),否则为\(\mbox{false}\)。
如果一颗子树需要变成相反数,其父亲也需要变成相反数,那么它就不需要再变成相反数,只要由它伟大的父亲干这件事就行了。
tag
标记下传:
seg[ls(k)]^=1;
seg[rs(k)]^=1;
这里还有一个细节,由于相反数的特殊性质,最大的取反会变成最小的,最小的取反会变成最大的,因此不要忘记swap
一下。
\(AC\ Code\)
\(262\) 行的代码,干着就是爽啊。
#include<cstdio>
#include<string>
#include<iostream>
using namespace std;
int read(){
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9'){
if(ch=='-')f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*f;
}
const int Maxn=2e5+5;
const int Inf=0x3f3f3f3f;
struct Adjacency_List{
int nxt,t;
int val;
}tr[Maxn<<1];
int h[Maxn];
int tot=1;//从1开始存边
void Add(int x,int y,int z){
tr[++tot].nxt=h[x];
tr[tot].t=y;
tr[tot].val=z;
h[x]=tot;
}
struct node{
int id;
int fa,son;
int top,dep;
int siz;
int pos;
}tp[Maxn];
int rev[Maxn];
void dfs1(int x,int fa){
tp[x].dep=tp[fa].dep+1;
tp[x].fa=fa;
tp[x].siz=1;
int &p=tp[x].son;
for(int i=h[x];i;i=tr[i].nxt){
int y=tr[i].t;
if(y==fa)continue;
dfs1(y,x);
tp[y].pos=i;
// printf("pos [%d] = %d\n",y,tp[y].pos);
// printf("%d -> %d\n",x,y);
tp[x].siz+=tp[y].siz;
if(!p||tp[p].siz<tp[y].siz)p=y;
}
}
int rpt[Maxn<<1];
int ksiz;//线段树节点编号
void dfs2(int x,int fa){
tp[x].id=++ksiz;
rev[ksiz]=tr[tp[x].pos].val;
rpt[tp[x].pos]=rpt[tp[x].pos^1]=ksiz;//记录某边在线段树中的位置
// printf("rpt [%d] = rpt [%d] = %d\n",tp[x].pos,tp[x].pos^1,ksiz);
// printf("id [%d] = %d\n",x,tp[x].id);
// printf("top [%d] = %d\n",x,tp[x].top);
if(!tp[x].son)return ;
tp[tp[x].son].top=tp[x].top;
dfs2(tp[x].son,x);
for(int i=h[x];i;i=tr[i].nxt){
int y=tr[i].t;
if(y==fa)continue;
if(y==tp[x].son)continue;
tp[y].top=y;
dfs2(y,x);
}
}
struct Segment_Tree{
int l,r;
int sum;
int maxv;
int minv;
bool tag;
}seg[Maxn<<2];
inline int ls(int k){return k<<1;}
inline int rs(int k){return k<<1|1;}
inline int push_up(int k){
seg[k].sum=seg[ls(k)].sum+seg[rs(k)].sum;
seg[k].maxv=max(seg[ls(k)].maxv,seg[rs(k)].maxv);
seg[k].minv=min(seg[ls(k)].minv,seg[rs(k)].minv);
// printf("seg [%d] = %d %d %d\n",k,seg[k].sum,seg[k].maxv,seg[k].minv);
}
inline void ops(int k){
seg[k].sum*=-1;
seg[k].maxv*=-1;
seg[k].minv*=-1;
swap(seg[k].maxv,seg[k].minv);
seg[k].tag^=1;
}
inline void push_down(int k){
if(!seg[k].tag)return ;
ops(ls(k));ops(rs(k));
seg[k].tag=0;
}
void build(int k,int l,int r){
// printf("In build(%d, %d, %d)\n",k,l,r);
seg[k].l=l;seg[k].r=r;
if(l==r){
// if(l==1){
// seg[k].sum=0;
// seg[k].maxv=-Inf;
// seg[k].minv=Inf;
// return ;
// }
seg[k].sum=seg[k].maxv=seg[k].minv=rev[l];
// cout<<"In build() rev "<<rev[l]<<'\n';
return ;
}
int mid=(l+r)>>1;
build(ls(k),l,mid);
build(rs(k),mid+1,r);
push_up(k);
}
void modifit1(int k,int q,int d){
if(seg[k].l==seg[k].r){
seg[k].sum=seg[k].maxv=seg[k].minv=d;
return ;
}
push_down(k);
int mid=(seg[k].l+seg[k].r)>>1;
if(q<=mid)modifit1(ls(k),q,d);
if(q>mid)modifit1(rs(k),q,d);
push_up(k);
}
void modifit2(int k,int ql,int qr){
if(ql>qr)return ;
if(ql<=seg[k].l&&seg[k].r<=qr){
ops(k);
return ;
}
push_down(k);
int mid=(seg[k].l+seg[k].r)>>1;
if(ql<=mid)modifit2(ls(k),ql,qr);
if(qr>mid)modifit2(rs(k),ql,qr);
push_up(k);
}
int query1(int k,int ql,int qr){
if(ql>qr)return 0;
// printf("In query_sum(%d, %d, %d)\n",k,ql,qr);
if(ql<=seg[k].l&&seg[k].r<=qr)
return seg[k].sum;
push_down(k);
int ret=0;
int mid=(seg[k].l+seg[k].r)>>1;
if(ql<=mid)ret+=query1(ls(k),ql,qr);
if(qr>mid)ret+=query1(rs(k),ql,qr);
push_up(k);
return ret;
}
int query2(int k,int ql,int qr){
if(ql>qr)return -Inf;
if(ql<=seg[k].l&&seg[k].r<=qr)
return seg[k].maxv;
push_down(k);
int ret=-Inf;
int mid=(seg[k].l+seg[k].r)>>1;
if(ql<=mid)ret=max(ret,query2(ls(k),ql,qr));
if(qr>mid)ret=max(ret,query2(rs(k),ql,qr));
push_up(k);
return ret;
}
int query3(int k,int ql,int qr){
// printf("In query_min(%d, %d, %d)\n",k,ql,qr);
if(ql>qr)return Inf;
if(ql<=seg[k].l&&seg[k].r<=qr)
return seg[k].minv;
push_down(k);
int ret=Inf;
int mid=(seg[k].l+seg[k].r)>>1;
if(ql<=mid)ret=min(ret,query3(ls(k),ql,qr));
if(qr>mid)ret=min(ret,query3(rs(k),ql,qr));
push_up(k);
return ret;
}
int main(){
int n=read();
for(int i=1;i<n;i++){
int x=read()+1;
int y=read()+1;
int z=read();
Add(x,y,z);
Add(y,x,z);
}
tp[1].top=1;
dfs1(1,0);
dfs2(1,0);
// for(int i=1;i<=n;i++)
// printf("[%d] id = %d, top = %d, dep = %d\n",i,tp[i].id,tp[i].top,tp[i].dep);
build(1,1,n);
int m=read();
for(int i=1;i<=m;i++){
char str[5];scanf("%s",str);
int x=read();
int y=read();
// cout<<str<<' '<<x<<' '<<y<<'\n';
if(str[0]=='C'){modifit1(1,rpt[x<<1],y);continue;}
x++;y++;
if(str[0]=='N'){
while(tp[x].top!=tp[y].top){
if(tp[tp[x].top].dep<tp[tp[y].top].dep)swap(x,y);
modifit2(1,tp[tp[x].top].id,tp[x].id);
x=tp[tp[x].top].fa;
}
if(tp[x].dep>tp[y].dep)swap(x,y);
modifit2(1,tp[x].id+1,tp[y].id);
}else if(str[0]=='S'){
int ans=0;
while(tp[x].top!=tp[y].top){
if(tp[tp[x].top].dep<tp[tp[y].top].dep)swap(x,y);
int tmp=query1(1,tp[tp[x].top].id,tp[x].id);
ans+=tmp;
// printf("temp = %d\n",tmp);
x=tp[tp[x].top].fa;
}
if(tp[x].dep>tp[y].dep)swap(x,y);
int tmp=query1(1,tp[x].id+1,tp[y].id);
ans+=tmp;
// printf("temp = %d\n",tmp);
// printf("ans = %d \n",ans);
printf("%d\n",ans);
}else if(str[1]=='A'){
int ans=-Inf;
while(tp[x].top!=tp[y].top){
if(tp[tp[x].top].dep<tp[tp[y].top].dep)swap(x,y);
int tmp=query2(1,tp[tp[x].top].id,tp[x].id);
ans=max(ans,tmp);
// printf("temp = %d\n",tmp);
x=tp[tp[x].top].fa;
}
if(tp[x].dep>tp[y].dep)swap(x,y);
int tmp=query2(1,tp[x].id+1,tp[y].id);
ans=max(ans,tmp);
// printf("temp = %d\n",tmp);
// printf("ans = %d \n",ans);
printf("%d\n",ans);
}else if(str[1]=='I'){
int ans=Inf;
while(tp[x].top!=tp[y].top){
if(tp[tp[x].top].dep<tp[tp[y].top].dep)swap(x,y);
int tmp=query3(1,tp[tp[x].top].id,tp[x].id);
ans=min(ans,tmp);
// printf("temp = %d\n",tmp);
x=tp[tp[x].top].fa;
}
if(tp[x].dep>tp[y].dep)swap(x,y);
int tmp=query3(1,tp[x].id+1,tp[y].id);
ans=min(ans,tmp);
// printf("temp = %d\n",tmp);
// printf("ans = %d \n",ans);
printf("%d\n",ans);
}
}
return 0;
}