// 我TM再也不写BIT套主席树了。。。。 (Author's note: "I am never writing a BIT nested over persistent segment trees (主席树) again...")
// Counts stored vertex pairs that "cross" between two vertices of a tree.
//
// Input (stdin): n, then n-1 tree edges; m, then m vertex pairs; q, then q
// operations "type x y":
//   type 1 -> insert pair (x, y)
//   type 2 -> remove pair (x, y)
//   other  -> print how many stored pairs have one endpoint in
//               S(x) = {v : x lies on the tree path v..y}
//             and the other endpoint in
//               S(y) = {v : y lies on the tree path v..x}.
//
// The tree is rooted at vertex 1 and every subtree is a DFS-order interval
// [dfn[v], dfn_end[v]]. Each pair is stored as the two points
// (dfn[x], dfn[y]) and (dfn[y], dfn[x]) in a 2-D counting structure: a Fenwick
// tree over the first coordinate whose slots are dynamic segment trees over the
// second coordinate. S(x) and S(y) are each a subtree interval or the
// complement of one, so every answer is one or two rectangle counts.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>

namespace {

const int kMaxV = 200500;           // vertex capacity (n must be < kMaxV)
const int kMaxE = 2 * kMaxV;        // each undirected tree edge is stored twice
const int kLog = 20;                // 2^20 > kMaxV: enough binary-lifting levels
const int kMaxNodes = kMaxV * 180;  // segment-tree node pool (no overflow check)

int n;

// Tree adjacency as linked lists: head[u] -> edges[i].nxt -> ... -> 0.
struct Edge {
    int v, nxt;
};
Edge edges[kMaxE];
int head[kMaxV], edge_count = 0;

// DFS data: subtree(v) occupies DFS positions [dfn[v], dfn_end[v]].
int dfn[kMaxV], dfn_end[kMaxV], dis[kMaxV], timer_ = 0;
int anc[kMaxV][kLog + 1];  // anc[v][k] = 2^k-th ancestor of v, 0 above the root
int dfs_stack[kMaxV], cur_edge[kMaxV];  // scratch for the iterative DFS

// Segment-tree node pool. Node 0 is the shared empty node: it is never
// written, so sum/ls/rs of 0 stay 0 and act as "no child".
int root[kMaxV];  // root[i] = segment tree stored in Fenwick slot i
int sum[kMaxNodes], ls[kMaxNodes], rs[kMaxNodes], tot = 0;

void add_edge(int u, int v) {
    ++edge_count;
    edges[edge_count].v = v;
    edges[edge_count].nxt = head[u];
    head[u] = edge_count;
}

// Pre-order DFS from `src` filling dis, anc[.][0], dfn and dfn_end.
// Written iteratively so a path-shaped tree with ~2e5 vertices cannot
// overflow the call stack; children are visited in the same adjacency order
// as a recursive DFS, so dfn numbering is identical.
void dfs(int src) {
    int top = 0;
    dfn[src] = ++timer_;
    cur_edge[src] = head[src];
    dfs_stack[++top] = src;
    while (top > 0) {
        const int u = dfs_stack[top];
        const int i = cur_edge[u];
        if (i == 0) {
            // All neighbours handled: the last assigned dfn closes u's subtree.
            dfn_end[u] = timer_;
            --top;
            continue;
        }
        cur_edge[u] = edges[i].nxt;
        const int v = edges[i].v;
        if (v == anc[u][0]) continue;  // edge back to the parent
        anc[v][0] = u;
        dis[v] = dis[u] + 1;
        dfn[v] = ++timer_;
        cur_edge[v] = head[v];
        dfs_stack[++top] = v;
    }
}

// Fills anc[.][k] for k >= 1 from the parent links set by dfs().
void build_lifting() {
    for (int k = 1; k <= kLog; ++k)
        for (int v = 1; v <= n; ++v)
            anc[v][k] = anc[anc[v][k - 1]][k - 1];
}

int lca(int a, int b) {
    if (dis[a] < dis[b]) std::swap(a, b);
    // Lift a to b's depth; the non-zero check stops us jumping past the root
    // (dis[0] == 0 would otherwise compare equal to the root's depth).
    for (int k = kLog; k >= 0; --k)
        if (anc[a][k] != 0 && dis[anc[a][k]] >= dis[b]) a = anc[a][k];
    if (a == b) return a;
    for (int k = kLog; k >= 0; --k)
        if (anc[a][k] != anc[b][k]) {
            a = anc[a][k];
            b = anc[b][k];
        }
    return anc[a][0];
}

// Returns the ancestor of `a` at depth dis[top] + 1, i.e. the child of `top`
// on the path to `a` when `top` is a proper ancestor of `a`. Returns `a`
// unchanged when a == top.
int child_towards(int a, int top) {
    for (int k = kLog; k >= 0; --k)
        if (dis[anc[a][k]] > dis[top]) a = anc[a][k];
    return a;
}

// Adds `val` at position `pos` (1..n) of Fenwick slot `slot`'s segment tree,
// in place, allocating only the nodes that do not exist yet.
void seg_add(int slot, int pos, int val) {
    if (root[slot] == 0) root[slot] = ++tot;
    int node = root[slot];
    int lo = 1, hi = n;
    sum[node] += val;
    while (lo < hi) {
        const int mid = (lo + hi) >> 1;
        int &child = (pos <= mid) ? ls[node] : rs[node];
        if (child == 0) child = ++tot;
        node = child;
        if (pos <= mid)
            hi = mid;
        else
            lo = mid + 1;
        sum[node] += val;
    }
}

// Sum over positions [l, r] in the subtree `node` that covers [left, right].
int seg_sum(int node, int left, int right, int l, int r) {
    if (node == 0 || l > r) return 0;  // empty subtree or empty range
    if (l <= left && right <= r) return sum[node];
    const int mid = (left + right) >> 1;
    int res = 0;
    if (l <= mid) res += seg_sum(ls[node], left, mid, l, r);
    if (r > mid) res += seg_sum(rs[node], mid + 1, right, l, r);
    return res;
}

// Adds `val` at point (x, pos) of the 2-D structure.
void bit_add(int x, int pos, int val) {
    for (int i = x; i <= n; i += i & -i) seg_add(i, pos, val);
}

// Number of points with first coordinate in [1, x] and second in [l, r].
int bit_prefix(int x, int l, int r) {
    int ret = 0;
    for (int i = x; i > 0; i -= i & -i) ret += seg_sum(root[i], 1, n, l, r);
    return ret;
}

// Number of points in [a1, a2] x [b1, b2].
int rect(int a1, int a2, int b1, int b2) {
    return bit_prefix(a2, b1, b2) - bit_prefix(a1 - 1, b1, b2);
}

// Inserts (val = 1) or removes (val = -1) the pair (x, y) in both orientations.
void update_pair(int x, int y, int val) {
    bit_add(dfn[x], dfn[y], val);
    bit_add(dfn[y], dfn[x], val);
}

// Counts stored pairs with one endpoint in S(x) and the other in S(y).
int answer(int x, int y) {
    const int t = lca(x, y);
    if (t == y) {
        // y is an ancestor of x: S(x) = subtree(x), S(y) = all but subtree(r).
        const int r = child_towards(x, y);
        return rect(dfn[x], dfn_end[x], 1, n) -
               rect(dfn[x], dfn_end[x], dfn[r], dfn_end[r]);
    }
    if (t == x) {
        // x is an ancestor of y: S(x) = all but subtree(r), S(y) = subtree(y).
        const int r = child_towards(y, x);
        return rect(1, n, dfn[y], dfn_end[y]) -
               rect(dfn[r], dfn_end[r], dfn[y], dfn_end[y]);
    }
    // Unrelated vertices: both sides are plain subtrees.
    return rect(dfn[x], dfn_end[x], dfn[y], dfn_end[y]);
}

}  // namespace

int main() {
    if (scanf("%d", &n) != 1) return 0;
    for (int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        add_edge(x, y);
        add_edge(y, x);
    }
    dfs(1);
    build_lifting();

    int m;
    scanf("%d", &m);
    for (int i = 1; i <= m; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        update_pair(x, y, 1);
    }

    int q;
    scanf("%d", &q);
    for (int i = 1; i <= q; ++i) {
        int type, x, y;
        scanf("%d%d%d", &type, &x, &y);
        if (type == 1)
            update_pair(x, y, 1);
        else if (type == 2)
            update_pair(x, y, -1);
        else
            printf("%d\n", answer(x, y));
    }
    return 0;
}