BZOJ 4034 [HAOI2015] 树上操作
树链剖分模板题：单点加、子树加、查询某点到根的路径和。
唯一需要注意的是数据范围：区间和与懒标记都可能超出 int 范围，必须用 long long 存储；计算“区间长度 × 增量”时也要先转成 (long long) 再相乘，否则会溢出。
// Heavy-light decomposition + lazy segment tree.
// Input: n m, then n node values, then n-1 undirected edges, then m operations:
//   1 x a : add a to node x
//   2 x a : add a to every node in the subtree rooted at x
//   3 x   : print the sum of node values on the path from x to the root (node 1)
// Segment-tree sums and lazy tags are long long; every "length * delta" product
// is formed in long long so it cannot overflow int.
#include<cstdio>
#include<cctype>
#include<algorithm>
using namespace std;

const int maxn=100001;  // nodes are numbered 1..maxn-1

int n,m;
int cnt;  // running DFS-order counter used by dfs2
int tot;  // running edge counter used by addedge
int son[maxn],fa[maxn],siz[maxn],val[maxn],top[maxn],a[maxn],dep[maxn],id[maxn];

// Segment-tree node over DFS positions [l, r]: range sum plus pending add (tag).
struct SegNode{int l,r;long long sum,tag;}tr[maxn<<2];

// Adjacency list (chained forward star); each undirected edge is stored twice.
int head[maxn],to[maxn<<1],nex[maxn<<1];

long long ans;  // accumulator written by queryway

// Reads one (possibly negative) decimal integer from stdin into x.
// Uses int for the character so EOF is handled safely (isdigit requires an
// unsigned-char value or EOF); stops skipping at EOF instead of spinning.
void read(int &x){
    int ch=getchar();x=0;int f=1;
    while(ch!=EOF&&!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=x*10+(ch-'0');ch=getchar();}
    x*=f;
}

// Adds the directed edge u -> v to the adjacency list.
void addedge(int u,int v){
    to[++tot]=v;nex[tot]=head[u];head[u]=tot;
}

// First pass: records depth, parent and subtree size of x (parent f), and
// picks the child with the largest subtree as the heavy child son[x].
// NOTE(review): recursion depth can reach n on a path-shaped tree.
void dfs1(int x,int f){
    dep[x]=dep[f]+1;fa[x]=f;siz[x]=1;
    int maxson=-1;
    for(int i=head[x];i;i=nex[i]){
        int v=to[i];
        if(v==f)continue;
        dfs1(v,x);
        siz[x]+=siz[v];
        if(siz[v]>maxson){maxson=siz[v];son[x]=v;}
    }
}

// Second pass: assigns DFS positions id[] (heavy child first, so each heavy chain
// is contiguous and the subtree of x occupies id[x]..id[x]+siz[x]-1), copies node
// values into a[] in that order, and records the chain head top[x].
void dfs2(int x,int topf){
    id[x]=++cnt;a[cnt]=val[x];top[x]=topf;
    if(!son[x])return;
    dfs2(son[x],topf);
    for(int i=head[x];i;i=nex[i]){
        int v=to[i];
        if(v==fa[x]||v==son[x])continue;
        dfs2(v,v);  // every light child starts a new chain
    }
}

// Builds segment-tree node `now` covering positions [l, r] of a[].
void buildtr(int now,int l,int r){
    tr[now].l=l;tr[now].r=r;
    if(l==r){tr[now].sum=a[l];return;}
    int mid=(l+r)>>1;
    buildtr(now<<1,l,mid);buildtr(now<<1|1,mid+1,r);
    tr[now].sum=tr[now<<1].sum+tr[now<<1|1].sum;
}

// Pushes the pending add of `now` down to its two children.
void pushdown(int now){
    if(tr[now].l==tr[now].r||tr[now].tag==0)return;
    long long t=tr[now].tag;
    int ls=now<<1,rs=now<<1|1;
    tr[ls].tag+=t;tr[rs].tag+=t;
    tr[ls].sum+=(long long)(tr[ls].r-tr[ls].l+1)*t;
    tr[rs].sum+=(long long)(tr[rs].r-tr[rs].l+1)*t;
    tr[now].tag=0;
}

// Adds `ad` to every position in [l, r].
void addtr(int now,int l,int r,int ad){
    if(tr[now].l>=l&&tr[now].r<=r){
        tr[now].tag+=ad;
        tr[now].sum+=(long long)(tr[now].r-tr[now].l+1)*ad;  // widen before multiplying
        return;
    }
    pushdown(now);
    int mid=(tr[now].l+tr[now].r)>>1;
    if(mid>=l)addtr(now<<1,l,r,ad);
    if(mid<r)addtr(now<<1|1,l,r,ad);
    tr[now].sum=tr[now<<1].sum+tr[now<<1|1].sum;
}

// Returns the sum of positions [l, r]; callers pass a range inside this node.
long long query(int now,int l,int r){
    if(tr[now].l>=l&&tr[now].r<=r)return tr[now].sum;
    pushdown(now);
    int mid=(tr[now].l+tr[now].r)>>1;
    if(mid>=r)return query(now<<1,l,r);
    if(mid<l)return query(now<<1|1,l,r);
    return query(now<<1,l,mid)+query(now<<1|1,mid+1,r);
}

// Adds the sum of node values on the tree path x..y to the global `ans`.
// Works for any pair of nodes: at each step the endpoint whose chain head is
// deeper jumps over its whole chain segment, until both lie on the same chain.
void queryway(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        ans+=query(1,id[top[x]],id[x]);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);  // x is now the shallower endpoint (the LCA)
    ans+=query(1,id[x],id[y]);
}

int main(){
    read(n);read(m);
    for(int i=1;i<=n;i++)read(val[i]);
    for(int i=1;i<n;i++){
        int u,v;read(u);read(v);
        addedge(u,v);addedge(v,u);
    }
    // Root the tree at node 1; passing 1 as its own parent is harmless because
    // no edge leads from node 1 back to itself.
    dfs1(1,1);dfs2(1,1);
    buildtr(1,1,n);
    for(int i=1;i<=m;i++){
        int opt,x,y;
        read(opt);read(x);
        switch(opt){
            case 1:read(y);addtr(1,id[x],id[x],y);break;
            case 2:read(y);addtr(1,id[x],id[x]+siz[x]-1,y);break;  // subtree is contiguous in DFS order
            case 3:ans=0;queryway(x,1);printf("%lld\n",ans);break;
        }
    }
    return 0;
}