luogu P4338 [ZJOI2018]历史
题面传送门
这道题给了没有修改的部分分,考虑怎么拿到这些分数。
然后发现我们可以枚举每个节点然后算贡献。
可以发现如果两个不同子树的access就会产生一个贡献。
然而如果一个子树过大那么就没有那么多贡献。
从而我们可以一次树形dp解决掉。
然而这个东西带修就不是很好做了。
可以发现这个东西和树链剖分的轻重链划分很像,所以我们照样维护LCT的虚边和实边,然后照样access,维护的时候讨论一下即可。
code:
#include<cstdio>
#define I inline
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define beg(x) int cur=s.h[x]
#define end cur
#define go cur=tmp.z
#define l(x) x<<1
#define r(x) x<<1|1
#define N 400039
#define ll long long
using namespace std;
int n,m,k,x,y,z,a[N],d[N];ll ans;
struct yyy{int to,z;};
struct ljb{
int head,h[N];yyy f[2*N];
I void add(int x,int y){f[++head]=(yyy){y,h[x]};h[x]=head;}
}s;
struct linkcuttree{
int fa[N],l[N],r[N];ll sum[N],val[N],siz[N];
I ll calc(int x,ll sum,ll w){return r[x]?(sum-w)*2:((val[x]*2>sum)?(sum-val[x])*2:sum-1);}
I void up(int x){sum[x]=sum[l[x]]+sum[r[x]]+val[x]+siz[x];}
I void dfs(int x,int last){
yyy tmp;sum[x]=val[x];int mx=0;fa[x]=last;
for(beg(x);end;go) tmp=s.f[cur],(tmp.to^last)&&(dfs(tmp.to,x),sum[x]+=sum[tmp.to],mx=sum[tmp.to]>sum[mx]?tmp.to:mx);
if(sum[mx]*2>sum[x]) r[x]=mx;ans+=calc(x,sum[x],sum[mx]);siz[x]=sum[x]-val[x]-sum[r[x]];
}
I void swap(int &x,int &y){x^=y^=x^=y;}
I int child(int x){return l[fa[x]]==x||r[fa[x]]==x;}
I int wrt(int x){return l[fa[x]]==x;}
I void rotate(int x){
int y=fa[x],z=fa[y],b=(x==l[y]?r[x]:l[x]);child(y)&&((y==l[z]?l[z]:r[z])=x);
(x==l[y])?(r[x]=y,l[y]=b):(l[x]=y,r[y]=b);fa[x]=z;fa[y]=x;b&&(fa[b]=y);up(y);up(x);
}
I void splay(int x){while(child(x)) child(fa[x])&&(rotate(wrt(x)^wrt(fa[x])?x:fa[x]),0),rotate(x);}
I void access(int x,int y){
splay(x);ll now=sum[x]-sum[l[x]],w=sum[r[x]];ans-=calc(x,now,w);
sum[x]+=y,val[x]+=y;now+=y;
if(w*2<=now) siz[x]+=w,r[x]=0;
ans+=calc(x,now,w);up(x);int z;
for(x=fa[z=x];x;x=fa[z=x]){
splay(x);now=sum[x]-sum[l[x]],w=sum[r[x]];ans-=calc(x,now,w);
sum[x]+=y,siz[x]+=y;now+=y;
if(w*2<=now) siz[x]+=w,w=r[x]=0;
if(sum[z]*2>now)siz[x]-=sum[z],r[x]=z,w=sum[z];
ans+=calc(x,now,w);up(x);
}
}
}g;
int main(){
freopen("1.in","r",stdin);
register int i;
scanf("%d%d",&n,&m);
for(i=1;i<=n;i++) scanf("%d",&g.val[i]);
for(i=1;i<n;i++) scanf("%d%d",&x,&y),s.add(x,y),s.add(y,x);g.dfs(1,0);
printf("%lld\n",ans);for(i=1;i<=m;i++) scanf("%d%d",&x,&y),g.access(x,y),printf("%lld\n",ans);
}