bzoj3730 震波 解题报告 (动态点分治)
由于一个小错误, 花了我一个上午的时间.....
题意
一棵 \(n\) 个点的树 ($ 1 \le n \le 10^5$), 每个点有一个点权 \(w[i]\) (\(1 \le w[i] \le 10^4\)).
有 \(m\) 个询问, 询问有两种,
- 将点 \(x\) 的权值改为 \(y\).
- 询问与点 \(x\) 距离不超过 \(k\) 的点的权值和.
思路
一
若不用修改, 且只有一个询问, 我们可以点分治, 对每个点 \(u\) 找出所有满足 \(dist(u,v) \le k - dist(u,x)\), 且与 \(x\) 不在同一子树的点 \(v\), 计算它们的权值和, 找到答案.
二
假设依旧不用修改, 但有多个询问, 考虑一下有什么数据是可以反复使用的.
发现由于树的结构与权值都不变, 那么与任一点 \(u\) 相距 \(t\) 的点的权值和是不变的, (\(u \in [1,n], t \in [0,n-1]\)).
所以, 我们可以对每个点 \(u\) 开一个 \(vector\), 设为 \(v1\), 存储在 \(u\) 的点分树子树中与它距离为 \(t\) 的点权和, 并对 \(v1\) 建立树状数组维护前缀和.
询问时, 在**点分树**上依次往上跳,
设当前点 (不是询问点) 为 \(u\), 用倍增求 \(lca\) 的方法找出 \(u\) 在点分树上的父亲 \(ft[u]\) 与**询问点 \(x\) **的距离 \(len\) ,在 \(v1[ft[u]]\) 上查询 \(k-len\) 的前缀和.
但仔细思考一下, 发现这样会重复计算与 \(x\) 在同一子树内的点的权值,
所以, 我们需要再开一个 \(vector\), 设为 \(v2\),
\(v2[u][t]\) 表示在 \(u\) 的在点分树的子孙中到 \(ft[u]\) 的距离为 \(t\) 的点的点权和, 类似地, 也对它建立树状数组维护前缀和.
那么, 我们只需要在询问时在点分树上逐层往上跳,
对于每一层的点 \(u\) , \(res+= sum(v1[ft[u]],k-len) - sum(v2[u],k-len)\) (\(sum\) 表示在树状数组上查询到前缀和).
三
现在, 考虑本题的实际情况 : 要修改, 多个询问.
其实考虑到了如何处理多个询问后, 修改也就不难了.
和询问类似, 修改时也是在点分树上逐层往上跳, 对每个当前点 \(u\) 算出 \(len = dist(ft[u],x)\),
然后 $ add(v1[ft[u]],\ len,\ y-val[x]), add(v2[u],\ len,\ y-val[x])$ ( \(add\) 表示树状数组的修改操作, \(y\) 表示要修改为的权值).
这道题就解完了.
还有一点要注意的是, 修改和查询时, 都不要忘记处理点 \(x\) 本身.
代码
#include<bits/stdc++.h>
#define uint unsigned int
#define pb push_back
#define sz size
using namespace std;
const int _=1e6+7;
const int __=1e6+7;
const int L=20;
const int inf=0x3f3f3f3f;
bool be;
int n,m,val[_],dis[_],dep[_],f[_][L+7],sz[_],rt,minx=inf,dpt,q[_],top,ft[_];
int lst[_],nxt[__],to[__],tot;
bool vis[_];
vector<int> v1[_],v2[_]; // v1: to self v2: to father
bool en;
void add(int x,int y){ nxt[++tot]=lst[x]; to[tot]=y; lst[x]=tot; }
void dbing(int u,int fa){
dep[u]=dep[fa]+1;
f[u][0]=fa;
for(int i=1;i<=L;i++)
f[u][i]=f[f[u][i-1]][i-1];
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa) continue;
dbing(v,u);
}
}
void add(int u,int x,int v,int id){
if(id==1) for(int i=x;i<(int)v1[u].sz();i+=i&(-i)) v1[u][i]+=v;
else for(int i=x;i<(int)v2[u].sz();i+=i&(-i)) v2[u][i]+=v;
}
int summ(int u,int x,int id){
int res=0;
if(id==1) for(int i=min(x,(int)v1[u].sz()-1);i>0;i-=i&(-i)) res+=v1[u][i];
else for(int i=min(x,(int)v2[u].sz()-1);i>0;i-=i&(-i)) res+=v2[u][i];
return res;
}
void g_rt(int u,int fa,int sum){
int maxn=0; sz[u]=1; dis[u]=dis[fa]+1;
dpt=max(dpt,dis[u]); q[++top]=u;
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa||vis[v]) continue;
g_rt(v,u,sum);
sz[u]+=sz[v];
maxn=max(maxn,sz[v]);
}
maxn=max(maxn,sum-sz[u]);
if(maxn<minx){ minx=maxn; rt=u; }
}
void cnt(int u,int fa,int rt){
dis[u]=dis[fa]+1; dpt=max(dpt,dis[u]); q[++top]=u;
for(int i=lst[u];i;i=nxt[i])
if(!vis[to[i]]&&to[i]!=fa) cnt(to[i],u,rt);
}
void calc(int u){
dpt=top=0;
for(int i=lst[u];i;i=nxt[i])
if(!vis[to[i]]) cnt(to[i],0,u);
v1[u].resize(dpt+2);
for(int i=1;i<=top;i++) add(u,dis[q[i]]+1,val[q[i]],1);
add(u,1,val[u],1);
}
void init(int u,int lrt,int sum){
minx=inf; dpt=top=0;
g_rt(u,0,sz[u]<sz[lrt] ?sz[u] :sum-sz[lrt]);
sum=sz[u];
v2[rt].resize(dpt+2);
for(int i=1;i<=top;i++) add(rt,dis[q[i]]+1,val[q[i]],2);
ft[rt]=lrt; vis[rt]=1; u=rt;
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(vis[v]) continue;
init(v,u,sum);
}
calc(u);
vis[u]=0;
}
int Lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int i=L;i>=0;i--)
if(dep[f[x][i]]>=dep[y])
x=f[x][i];
if(x==y) return x;
for(int i=L;i>=0;i--)
if(f[x][i]!=f[y][i]){
x=f[x][i];
y=f[y][i];
}
return f[x][0];
}
int dist(int x,int y){ return dep[x]+dep[y]-2*dep[Lca(x,y)]; }
int query(int x,int k){
int len,fa=ft[x],u=x,res=summ(u,k+1,1);
while(fa){
len=k-dist(fa,x);
res+=summ(fa,len+1,1)-summ(u,len+1,2);
u=ft[u]; fa=ft[u];
}
return res;
}
void modify(int x,int v){
int len,fa=ft[x],u=x;
add(u,1,v-val[x],1);
while(fa){
len=dist(fa,x);
add(fa,len+1,v-val[x],1);
add(u,len+1,v-val[x],2);
//printf("u: %d fa: %d len: %d %d \n",u,fa,len,v-val[x]);
u=ft[u]; fa=ft[u];
}
val[x]=v;
}
void run(){
int ty,x,y,lst=0;
for(int i=1;i<=m;i++){
scanf("%d%d%d",&ty,&x,&y);
x^=lst; y^=lst;
if(!ty){ lst=query(x,y); printf("%d\n",lst); }
else modify(x,y);
}
}
int main(){
//freopen("x.in","r",stdin);
//freopen("new.out","w",stdout);
cin>>n>>m;
for(int i=1;i<=n;i++) scanf("%d",&val[i]);
int x,y;
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dbing(1,0);
init(1,0,n);
run();
return 0;
}