Codeforces 165D Beard Graph 边权树剖+树状数组
题意:给你一颗由n个结点组成的树,支持以下操作:1 i:将第i条边染成黑色(保证此时该边是白色),2 i:将第i条边染成白色(保证此时该边是黑色),3 a b:找出a,b两点之间只由黑边组成的最短路径.
思路:树链剖分+树状数组,把每条边的权值放到它指向的点中去,初始全为黑边,黑边权值为1,白边权值为-inf,黑边变白边,将点权增加-inf,白边变黑边点权增加inf,因为不可能白边变白边,所以可以这样做,查询的时候要减去2个点的最近公共祖先的点权,最近公共祖先可通过树剖的get函数求,查询结果<0说明不可能到达
AC代码:
#include "iostream" #include "string.h" #include "stack" #include "queue" #include "string" #include "vector" #include "set" #include "map" #include "algorithm" #include "stdio.h" #include "math.h" #define ll long long #define bug(x) cout<<x<<" "<<"UUUUU"<<endl; #define mem(a,x) memset(a,x,sizeof(a)) #define mp(x,y) make_pair(x,y) #define pb(x) push_back(x) using namespace std; const long long INF = 1e18+1LL; const int inf = 1e9+1e8; const int N=1e5+100; const ll mod=1e9+7; int to[N<<1],nex[N<<1],head[N],tot=2; int siz[N],son[N],tip[N],top[N],dep[N],fa[N],cnt=0; int n,m; ll C[N]; map<int,int> M; void add(int u, int v){ to[tot]=v; nex[tot]=head[u]; head[u]=tot++; } void dfs1(int u, int f){ siz[u]=1; fa[u]=f; dep[u]=dep[f]+1; for(int i=head[u]; i!=-1; i=nex[i]){ int v=to[i]; if(v==f) continue; M[i>>1]=v; dfs1(v,u); siz[u]+=siz[v]; if(siz[v]>siz[son[u]]) son[u]=v; } } void dfs2(int u, int tp){ tip[u]=++cnt; top[u]=tp; if(son[u]) dfs2(son[u],tp); for(int i=head[u]; i!=-1; i=nex[i]){ int v=to[i]; if(v!=fa[u] && v!=son[u]) dfs2(v,v); } } int lowbit(int x){ return (-x)&x; } void up(int x, ll c){ while(x<=n){ C[x]+=c; x+=lowbit(x); } } ll sum(int l, int r){ ll ret=0; l--; while(l>0){ ret-=C[l]; l-=lowbit(l); } while(r>0){ ret+=C[r]; r-=lowbit(r); } return ret; } void get_sum(int u, int v){ ll ans=0; while(top[u]!=top[v]){ if(dep[top[u]] < dep[top[v]]) swap(u,v); ans+=sum(tip[top[u]], tip[u]); u=fa[top[u]]; } if(dep[u] > dep[v]) swap(u,v); //u为LCA ans+=sum(tip[u],tip[v]); ans-=sum(tip[u],tip[u]); if(ans<0) cout<<"-1\n"; else cout<<ans<<"\n"; } int main(){ ios::sync_with_stdio(false),cin.tie(0),cout.tie(0); cin>>n; mem(head,-1); int c,u,v; for(int i=1; i<n; ++i){ cin>>u>>v; add(u,v); add(v,u); } dfs1(1,1); dfs2(1,1); for(int i=2; i<=n; ++i){ up(i,1); } cin>>m; while(m--){ cin>>c>>u; if(c==3){ cin>>v; get_sum(u,v); } else if(c==2) up(tip[M[u]],-inf); else{ //ll t=sum(1,tip[M[u]])-sum(1,tip[M[u]]-1); up(tip[M[u]],inf); } } return 0; }