又一次把lct写炸了,硬着头皮终于改对了
#include<iostream> #include<cstring> #include<cstdio> #include<cmath> #include<algorithm> using namespace std; const int maxn=100010; struct node{ int fa,ls,rs,is_root; }tr[maxn*2]; int t,tot[maxn*2],sum[maxn*2],cnt,last[maxn*2],pre[maxn*2],to[maxn*2];//tot表示这个点连出的虚边子树和; int n,q,col[maxn],f[maxn]; void add(int x,int y){++t;pre[t]=last[x];last[x]=t;to[t]=y;} void update(int x){ sum[x]=tot[x]; if(tr[x].ls!=0&&tr[x].ls!=n+1)sum[x]+=sum[tr[x].ls]; if(tr[x].rs!=0&&tr[x].rs!=n+1)sum[x]+=sum[tr[x].rs]; } void rx(int x){ int y=tr[x].fa,z=tr[y].fa; tr[y].ls=tr[x].rs; if(tr[x].rs!=0&&tr[x].rs!=n+1)tr[tr[x].rs].fa=y; tr[x].rs=y;tr[y].fa=x; tr[x].fa=z; if(z!=0&&z!=n+1&&!tr[y].is_root){ if(tr[z].ls==y)tr[z].ls=x;else tr[z].rs=x; } if(tr[y].is_root)tr[y].is_root=0,tr[x].is_root=1; update(y);update(x); } void lx(int x){ int y=tr[x].fa,z=tr[y].fa; tr[y].rs=tr[x].ls; if(tr[x].ls!=0&&tr[x].rs!=n+1)tr[tr[x].ls].fa=y; tr[x].ls=y;tr[y].fa=x; tr[x].fa=z; if(z&&z!=n+1&&!tr[y].is_root){ if(tr[z].ls==y)tr[z].ls=x;else tr[z].rs=x; } if(tr[y].is_root)tr[y].is_root=0,tr[x].is_root=1; update(y);update(x); } void splay(int x){ while(!tr[x].is_root){ int y=tr[x].fa,z=tr[y].fa; if(tr[y].is_root){if(tr[y].ls==x)rx(x);else lx(x);} else{ if(tr[z].ls==y&&tr[y].ls==x){rx(y);rx(x);} else if(tr[z].ls==y&&tr[y].rs==x){lx(x);rx(x);} else if(tr[z].rs==y&&tr[y].ls==x){rx(x);lx(x);} else {lx(y);lx(x);} } } } void ace(int x){ for(int p=0;x!=0&&x!=n+1;x=tr[x].fa){ splay(x); if(tr[x].rs!=0&&tr[x].rs!=n+1){ tr[tr[x].rs].is_root=1; tot[x]+=sum[tr[x].rs]; } if(p!=0&&p!=n+1){ tot[x]-=sum[p]; } tr[tr[x].rs=p].is_root=0; update(p=x); } } void link(int x,int y){//x是y的父亲 if(x==0||x==n+1)return; ace(x);splay(x);splay(y); tr[y].fa=x;tr[x].rs=y;tr[y].is_root=0;//一开始最后这句丢了; update(x); } void cut(int x,int y){//y是x的父亲 if(y==0||y==n+1)return; ace(x);splay(x);tr[tr[x].ls].fa=0;tr[tr[x].ls].is_root=1;tr[x].ls=0;update(x); } void dfs(int x,int fa){ for(int i=last[x];i;i=pre[i]){ int v=to[i]; if(v==fa)continue; link(x,v);f[v]=x; dfs(v,x); } } int query(int x){ int tmp1=x,tmp2; if(col[x])x+=n+1; ace(x); splay(x); while(tr[x].ls){ x=tr[x].ls; } splay(x); if(col[tmp1])tmp2=x-n-1; else tmp2=x; if(col[tmp2]!=col[tmp1])return sum[tr[x].rs]; else {return sum[x];} } int main(){ int x,y,op; cin>>n; for(int i=1;i<n;++i){ scanf("%d %d",&x,&y); add(x,y);add(y,x); } for(int i=1;i<=n;++i){ sum[i]=tot[i]=1; tr[i].is_root=1; } for(int i=n+2;i<=2*n+2;++i)tr[i].is_root=1; dfs(1,0); cin>>q; for(int i=1;i<=q;++i){ scanf("%d %d",&op,&x); if(op){ if(col[x]){ cut(x+n+1,f[x]+n+1); tot[x+n+1]-=1;sum[x+n+1]-=1; tot[x]+=1;sum[x]+=1; link(f[x],x); } else{ cut(x,f[x]); tot[x]-=1;sum[x]-=1; tot[x+n+1]+=1;sum[x+n+1]+=1; link(f[x]+n+1,x+n+1); } col[x]^=1; } else{ printf("%d\n",query(x)); } } return 0; }