suoi08 一收一行破 (tarjanLca+树状数组)
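用一个差分树状数组按深度维护每一层的和：初始时在每个结点所在深度上加上它的权值；对每条路径操作，先用 Tarjan 离线求出两个端点的 LCA，给 LCA 所在深度加 1，再分别给深度区间 [dep(LCA)+1, dep(端点)] 整体加 1（差分：区间左端加 1、右端的下一位减 1）；查询某个深度时对树状数组求前缀和即可。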
// Offline solver: per-depth totals on a rooted tree, updated by path operations.
//
// Input read by main (whitespace-separated integers on stdin):
//   N M
//   v[1..N]                  initial value of every node
//   N-1 lines "a b"          undirected tree edges; node 1 is the root (depth 1)
//   M operations:
//     "1 b c"  -> every node on the tree path b..c adds 1 to the total of its
//                 own depth (the LCA's depth once, then each depth on the b
//                 side and on the c side below the LCA once);
//     "t d"    -> (any t != 1) print the current total stored for depth d,
//                 or 0 when d is outside [1, maximum depth].
//
// Method: depth totals live in a Fenwick tree used as a difference array
// (range add = two point updates, point read = prefix sum). LCAs of all path
// operations are computed up front with Tarjan's offline algorithm, then the
// operations are replayed in input order.
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
using ll = long long;

const int maxn = 200020;

// Reads one signed integer from stdin. Returns 0 if EOF is reached before any
// digit. `c` is an int so that EOF can be told apart from a real character.
inline ll rd() {
    ll x = 0;
    int c = getchar();
    int neg = 1;
    while ((c < '0' || c > '9') && c != EOF) {
        if (c == '-') neg = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * neg;
}

int N, M;
int L = 1;                            // largest depth seen; also the Fenwick tree's size
int eg[maxn * 2][2], egh[maxn], ect;  // adjacency list: eg[e] = {to, next}; egh[u] = first edge (-1 = none)
int dep[maxn], fa[maxn];              // node depth (root = 1); union-find parent
ll v[maxn];                           // initial node values
int qu[maxn * 2][2], qh[maxn];        // LCA query list: entries 2i and 2i+1 belong to operation i
int op[maxn][3];                      // op[i] = {LCA (type 1) or queried depth (other), b, c}
int stk[maxn];                        // explicit DFS stack used by tarjan()
int cur[maxn];                        // next adjacency entry to scan for each node on that stack
ll tr[maxn];                          // Fenwick tree over depths 1..L
bool flag[maxn];                      // node already reached by the DFS

inline int lowbit(int x) { return x & (-x); }

// Union-find root lookup with full path compression. It is done iteratively
// because Tarjan's merge order can build a parent chain as long as the tree
// is deep, and a recursive lookup could overflow the call stack.
inline int getfa(int x) {
    int root = x;
    while (fa[root] != root) root = fa[root];
    while (fa[x] != root) {  // point every node on the walked path at the root
        int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}

// Appends the directed edge a -> b to a's adjacency list.
inline void adeg(int a, int b) {
    eg[++ect][0] = b;
    eg[ect][1] = egh[a];
    egh[a] = ect;
}

// Registers query entry i, meaning "LCA of a with b", in a's query list.
inline void adq(int a, int b, int i) {
    qu[i][0] = b;
    qu[i][1] = qh[a];
    qh[a] = i;
}

// Fenwick point update. Positions outside [1, L] are ignored; the closing
// "+1" end of a difference range may land on L + 1.
inline void add(int x, ll y) {
    for (; x > 0 && x <= L; x += lowbit(x)) tr[x] += y;
}

// Fenwick prefix sum over positions 1..x. Under the difference encoding this
// is the current total stored for depth x.
inline ll query(int x) {
    ll re = 0;
    for (; x > 0; x -= lowbit(x)) re += tr[x];
    return re;
}

// Tarjan's offline LCA over the tree rooted at `root`, using an explicit stack
// so that a path-shaped tree with ~2e5 nodes cannot overflow the call stack.
// dep[root] must already be set and fa[] must be initialised to fa[u] = u.
// Effects: sets dep[] for every other reached node, raises L to the maximum
// depth, and writes the LCA of each type-1 operation i into op[i][0].
void tarjan(int root) {
    int top = 0;
    flag[root] = 1;
    cur[root] = egh[root];
    stk[top++] = root;
    while (top > 0) {
        int x = stk[top - 1];
        // Skip neighbours that were already reached (in a tree, the parent).
        while (cur[x] != -1 && flag[eg[cur[x]][0]]) cur[x] = eg[cur[x]][1];
        if (cur[x] != -1) {
            int y = eg[cur[x]][0];
            cur[x] = eg[cur[x]][1];  // resume after this edge once y's subtree is done
            dep[y] = dep[x] + 1;
            L = max(L, dep[y]);
            flag[y] = 1;
            cur[y] = egh[y];
            stk[top++] = y;
            continue;
        }
        // All children of x are finished. For every query whose other endpoint
        // has already been reached, that endpoint's set representative is the LCA.
        for (int j = qh[x]; j != -1; j = qu[j][1])
            if (flag[qu[j][0]]) op[j >> 1][0] = getfa(qu[j][0]);
        --top;
        // Merge x's set into its parent's set, exactly where the recursive
        // formulation did it (right after the child call returned).
        if (top > 0) fa[getfa(x)] = getfa(stk[top - 1]);
    }
}

int main() {
    N = rd();
    M = rd();
    memset(egh, -1, sizeof(egh));
    memset(qh, -1, sizeof(qh));
    for (int i = 1; i <= N; i++) v[i] = rd();
    for (int i = 1; i < N; i++) {
        int a = rd(), b = rd();
        adeg(a, b);
        adeg(b, a);
    }
    for (int i = 1; i <= M; i++) {
        int a = rd(), b = rd();
        if (a == 1) {
            int c = rd();
            // Register the LCA query from both endpoints; ids 2i and 2i+1 map
            // back to operation i through i >> 1.
            adq(b, c, i << 1);
            adq(c, b, i << 1 | 1);
            op[i][1] = b;
            op[i][2] = c;
        } else {
            op[i][0] = b;
        }
    }
    for (int i = 1; i <= N; i++) fa[i] = i;
    dep[1] = 1;
    tarjan(1);
    // Seed: depth d starts out holding the sum of v over the nodes at depth d.
    for (int i = 1; i <= N; i++) {
        add(dep[i], v[i]);
        add(dep[i] + 1, -v[i]);
    }
    for (int i = 1; i <= M; i++) {
        if (op[i][1]) {  // nonzero only for type-1 operations (node ids start at 1)
            int lca = op[i][0];
            add(dep[lca], 1);
            add(dep[lca] + 1, -1);
            if (lca != op[i][1]) {  // depths dep[lca]+1 .. dep[b]
                add(dep[lca] + 1, 1);
                add(dep[op[i][1]] + 1, -1);
            }
            if (lca != op[i][2]) {  // depths dep[lca]+1 .. dep[c]
                add(dep[lca] + 1, 1);
                add(dep[op[i][2]] + 1, -1);
            }
        } else {
            int d = op[i][0];
            // Depths outside [1, L] hold nothing. A negative d would otherwise
            // index tr[] out of bounds.
            if (d < 1 || d > L)
                printf("0\n");
            else
                printf("%lld\n", query(d));
        }
    }
    return 0;
}