bzoj 3924
动态点分治好题
首先我们考虑一个暴力做法:
每次修改之后选一个点作为根搜索整棵树,然后换根dp即可
考虑每次换根时,移向的点的消耗会减少子树代价之和*边权,而其余部分代价会增加剩余代价*边权
这样每次换根都是$O(1)$的,总时间复杂度$O(nm)$,可以通过...20分!
贴代码:
#include <cstdio> #include <cmath> #include <cstring> #include <cstdlib> #include <iostream> #include <algorithm> #include <queue> #include <stack> #define ll long long using namespace std; int n,m; struct Edge { int next; int to; int val; }edge[10005]; int head[5005]; ll num[5005]; int son[5005]; int cnt=1; int s=0; void init() { memset(head,-1,sizeof(head)); cnt=1; } void add(int l,int r,int w) { edge[cnt].next=head[l]; edge[cnt].to=r; edge[cnt].val=w; head[l]=cnt++; } ll ans=0; void dfs1(int x,int fx,ll d) { ans+=(ll)(num[x]*d); son[x]+=num[x]; for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(to==fx) { continue; } dfs1(to,x,(ll)d+edge[i].val); son[x]+=son[to]; } } void dfs2(int x,int fx,ll tot) { for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(to==fx) { continue; } ans=min(ans,(ll)tot+(ll)(s-2*son[to])*edge[i].val); dfs2(to,x,tot+(ll)(s-2*son[to])*edge[i].val); } } int main() { scanf("%d%d",&n,&m); init(); for(int i=1;i<n;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); add(x,y,z); add(y,x,z); } for(int i=1;i<=m;i++) { int x,y; scanf("%d%d",&x,&y); num[x]+=y; s+=y; ans=0; memset(son,0,sizeof(son)); dfs1(1,1,0); dfs2(1,1,ans); printf("%lld\n",ans); } return 0; }
然后我们考虑正解
我们发现:这个换根的过程有两个可以优化的地方:
首先,并不是所有点都可能作为根的,有一些点显然不会作为根,而这一部分其实应该剪掉!
第二:每次修改只是重构了这棵树一条树链上的信息,因此我们每次修改就重新搜索整棵树是不合算的
因此我们就可以考虑用点分治来优化这个过程了
当然了,由于涉及修改,所以肯定是动态点分治
用上点分治之后,第一点可以用类似这道题的方法来解决,即向各个方向走,找一个最优的方向去走即可
第二点就可以直接在点分树上暴力走了,这样就结束了
然后我们可以考虑如何修改和查询了
对每个节点维护三个值,分别维护自己子树内的权值,自己子树内的总权值贡献和自己子树中权值对父节点的贡献
注意区分权值与权值贡献:权值贡献考虑边权!
修改的时候直接暴力修改即可
假设我们想查询以某个点为根的贡献,那么我们就需要统计两部分:一部分是自己子树内的,另一部分是自己子树外的
自己子树内的可以直接加
自己子树外的我们可以沿着点分树向上跳,跳到的每个节点在子树中要去掉来源点那部分子树的贡献,然后加上当前点的权值贡献即可
然后每次查询的时候先任意指定一个点然后向四周走,如果走到的点比当前点更优那么就向那个方向走
有个小trick:在点分树上每个点深度都是对数级别,因此我们可以直接预处理出每个节点与它的几层父亲的距离,查询时直接调用就可以了
贴代码:
#include <cstdio> #include <cmath> #include <cstring> #include <cstdlib> #include <iostream> #include <algorithm> #include <queue> #include <stack> using namespace std; typedef long long ll; struct Edge { int nxt,to; ll val; }edge[200005]; int head[100005]; int pre[100005]; ll dis[100005]; int dep[100005],soz[100005],son[100005]; int siz[100005],maxp[100005]; int vis[100005],ttop[100005],f[100005]; ll w1[100005],w2[100005],w3[100005]; ll d[100005][25]; int rt,s,rot; int n,m; int cnt=1; void add(int l,int r,ll w) { edge[cnt].nxt=head[l]; edge[cnt].to=r; edge[cnt].val=w; head[l]=cnt++; } void get_rt(int x,int fx) { siz[x]=1,maxp[x]=0; for(int i=head[x];i;i=edge[i].nxt) { int to=edge[i].to; if(to==fx||vis[to])continue; get_rt(to,x); siz[x]+=siz[to],maxp[x]=max(maxp[x],siz[to]); } maxp[x]=max(maxp[x],s-siz[x]); if(maxp[x]<maxp[rt])rt=x; } void solve(int x) { vis[x]=1; for(int i=head[x];i;i=edge[i].nxt) { int to=edge[i].to; if(vis[to])continue; rt=0,s=siz[to]; get_rt(to,x); pre[rt]=x,solve(rt); } } void dfs(int x,int fx,int deep) { soz[x]=1,dis[x]=deep,dep[x]=dep[fx]+1,f[x]=fx; for(int i=head[x];i;i=edge[i].nxt) { int to=edge[i].to; if(to==fx)continue; dfs(to,x,deep+edge[i].val); soz[x]+=soz[to],son[x]=(soz[son[x]]<soz[to])?to:son[x]; } } void redfs(int x,int topx,int fx) { ttop[x]=topx; if(son[x])redfs(son[x],topx,x); for(int i=head[x];i;i=edge[i].nxt) { int to=edge[i].to; if(to==fx||to==son[x])continue; redfs(to,to,x); } } int LCA(int x,int y) { while(ttop[x]!=ttop[y]) { if(dep[ttop[x]]<dep[ttop[y]])swap(x,y); x=f[ttop[x]]; } return dep[x]<dep[y]?x:y; } ll get_dis(int x,int y) { return dis[x]+dis[y]-2*dis[LCA(x,y)]; } void update(int x,ll v) { for(int i=x,j=0,k=0;i;j=i,i=pre[i],k++) { ll temp=d[x][k]*v; w1[i]+=temp,w3[i]+=v; if(j)w2[j]+=temp; } } ll query(int x) { ll ret=0; for(int i=x,j=0,k=0;i;j=i,i=pre[i],k++)ret+=w1[i]-w2[j]+(w3[i]-w3[j])*d[x][k]; return ret; } int divi(int x) { ll ori=query(x); int ret=-1; for(int i=head[x];i;i=edge[i].nxt) { int to=edge[i].to; ll temp=query(to); if(temp<ori)ret=to,ori=temp; } return ret; } template <typename T>inline void read(T &x) { T f=1,c=0;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){c=c*10+ch-'0';ch=getchar();} x=c*f; } int main() { read(n),read(m); for(int i=1;i<n;i++) { int x,y; ll z; read(x),read(y),read(z); add(x,y,z),add(y,x,z); } dfs(1,1,0),redfs(1,1,1); rt=0,maxp[0]=0x3f3f3f3f,s=n; get_rt(1,0),rot=rt,solve(rt); for(int i=1;i<=n;i++)for(int k=1,j=pre[i];j;k++,j=pre[j])d[i][k]=get_dis(i,j); int x;ll va; read(x),read(va),m--; update(x,va),printf("0\n"); while(m--) { read(x),read(va); update(x,va); int ans=rot; int t=divi(rot); while(t!=-1)ans=t,t=divi(t); printf("%lld\n",query(ans)); } return 0; }