<学习笔记> 点分树

感觉可以理解为带修点分治。

常用于解决与树原形态无关的带修改问题。 —— oi-wiki

点分树是通过更改原树形态使树的层数变为稳定 logn 的一种重构树。就是通过点分治找重心的方式,将这一层重心为上一层重心的儿子。

所以对于很多暴力的复杂度是正确的。

一开始发现建树错了,然后发现是原先的点分治错了,然后才知道其实错误的求重心也可能会有正确的复杂度

2023NOIP A层联测21 距离#

子任务二可以直接拿点分树做,就是对于每个重心维护到这个点距离最近的点的距离。然后查询时就将重心看作 lca,直接暴跳就可以了。

点击查看代码
void find(int x,int f){
    siz[x]=1;
    mx[x]=0;
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(y==f || vis[y]) continue; 
        find(y,x);
        siz[x]+=siz[y]; 
        mx[x]=max(mx[x],siz[y]);
    }
    mx[x]=max(mx[x],S-siz[x]);
    if(mx[root]>mx[x]) root=x;
}
void solve(int x){
    vis[x]=1;
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(vis[y]) continue;
        S=siz[y],root=0;
        find(y,x);
        fat[root]=x;
        solve(root);
    }
}
int mxxx[N];
void updata(int x){
    for(int i=x;i;i=fat[i]){
        int dis=ask_dist(i,x);
        mxxx[i]=min(mxxx[i],dis);
    }
}
int query(int x){
    int ans=inf;
    for(int i=x;i;i=fat[i]){
        int dis=ask_dist(x,i);
        if(mxxx[i]>(1ll<<50)) continue;
        ans=min(ans,dis+mxxx[i]);
    }
    return ans;
}

然后考虑两对点,做法就是点分治套点分树,对 x,a 进行点分治,然后将在这一层的操作处理掉,按顺序对 y,b 进行点分树上的修改和查询。具体就是先建一棵点分树,然后将每个操作压入所有涉及到的分治重心上,这里直接在点分树上跳就可以,可以发现每个操作最多被处理 logn 次。复杂度就是 O(nlog2n)

震波#

虽说是道板子,但夹带的东西有点多。

还是建点分树,然后对每个点分树维护当前层到这个点距离为 k 的贡献,这里数组直接开开不下,所以只需要开一个当前层大小的 vector 就可以,然后 C[0][x].resize(size[x]+4) 就很好用。

注意,因为枚举的重心当作 lca,所以查询 x,k ,对于重心 y,你要查的是 y 子树内距离为 kdistx,y 的贡献,但是对于在 y 的在 x 方向上的子树这样查就不可以,所以要排除掉。所以对每个点开两个 vector,一个存对当先点的贡献,一个对存当前点父亲的贡献,直接查的时候,用 y 父亲的减去 y 的算的是父亲中的贡献。

code
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;
int val[N];
bool vis[N];
int size[N],mx[N],S,root;
int head[N*2],ver[N*2],nex[N*2],d[N],f[N][50],tot=0,fa[N];
int n,m;
int t;
void add(int x,int y){
    ver[++tot]=y,nex[tot]=head[x],head[x]=tot;
}
vector<int> C[3][N];
void dfs(int x,int fat){
    f[x][0]=fat;
    d[x]=d[fat]+1;
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(y==fat) continue;
        dfs(y,x);
    }
}
void init(){
    t=log2(n)+1;
    for(int j=1;j<=t;j++)
        for(int i=1;i<=n;i++)
            f[i][j]=f[f[i][j-1]][j-1];
}
int ask_lca(int x,int y){
    if(d[x]>d[y]) swap(x,y);
    for(int i=t;i>=0;i--)
        if(d[f[y][i]]>=d[x]) y=f[y][i];
    if(x==y) return x;
    for(int i=t;i>=0;i--){
        if(f[x][i]!=f[y][i]){
            x=f[x][i],y=f[y][i];
        } 
    }
    return f[x][0];
}
int ask_dist(int x,int y){
    int lca=ask_lca(x,y);
    return d[x]+d[y]-2*d[lca];
}
inline int lowbit(int x){
    return x&(-x);
}
void change(int op,int x,int c,int u){
    x++;
    int lim=C[op][u].size();
    for(;x<lim;x+=lowbit(x)) C[op][u][x]+=c;
}
int query(int op,int x,int u){
    x++;
    x=min(x,(int)C[op][u].size()-1);
    int ans=0;
    for(;x;x-=lowbit(x)) ans+=C[op][u][x];
    return ans;
}
void find(int x,int fat){
    size[x]=1;
    mx[x]=0;
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(y==fat || vis[y]) continue;
        find(y,x);
        size[x]+=size[y];
        mx[x]=max(mx[x],size[y]);
    }
    mx[x]=max(mx[x],S-size[x]);
    if(mx[x]<mx[root]) root=x;
}
void work(int x,int fat){
    S++;
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(y==fat || vis[y]) continue;
        work(y,x);
    }
}
void solve(int x){
    vis[x]=1;
    C[0][x].resize(S+2);
    C[1][x].resize(S+2);
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(vis[y]) continue;
        S=size[y];
        root=0;
        find(y,x);
        fa[root]=x;
        solve(root);
    }
}
void updata(int x,int v){
    for(int i=x;i;i=fa[i]) change(0,ask_dist(x,i),v,i);
    for(int i=x;fa[i];i=fa[i]) change(1,ask_dist(x,fa[i]),v,i);
}
signed main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%d",&val[i]);
    for(int i=1;i<n;i++){
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y),add(y,x);
    }
    dfs(1,0);
    init();
    root=0;
    mx[0]=n+1;
    S=n;
    find(1,0);
    solve(root);
    for(int i=1;i<=n;i++) updata(i,val[i]);
    int las=0;
    for(int i=1;i<=m;i++){
        int op,x,y;
        scanf("%d%d%d",&op,&x,&y);
        x^=las,y^=las;
        if(op) updata(x,y-val[x]),val[x]=y;
        else{
            int res=query(0,y,x);
            for(int i=x;fa[i];i=fa[i]){
                int dis=ask_dist(x,fa[i]);
                if(y>=dis) res+=((query(0,(y-dis),fa[i])-query(1,(y-dis),i)));
            }
            las=res;
            printf("%d\n",res);
        }
    }
}

永雏塔菲#

正解并不是点分树,但是可以点分树这么做。

考虑每一个分值层,从分治中心开始求一遍 dfs 序,以 dfs 序为下标建一棵存分治中心到所有点距离的线段树。每个分治中心维护一个 set 里面存它的每个儿子贡献的最长链,那么这个分治中心贡献的答案就是自己加上最长的两条链或0。维护一个全局可删堆维护每个分治节点的答案,那么 1 就是全局可删堆第一个; 2 可以先对那个点加一个极大值查询后再减去;3 跳重构树,对于每个分治节点就是在线段树上修改一段区间,在更新答案。最重要的是暴力卡常。

code
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+10;
const int inf=(1ll<<50);
template<typename T, typename _Compare>
class DeletableHeap {
public:
    typedef std::priority_queue<T, std::vector<T>, _Compare> Heap;

private:
    Heap heap, deleted;

public:
    DeletableHeap() = default;

    void insert(const T &value) {
        heap.push(value);
    }

    void erase(const T &value) {
        deleted.push(value);
    }

    void pop() {
        while (!deleted.empty() && heap.top() == deleted.top()) {
            heap.pop();
            deleted.pop();
        }

        heap.pop();
    }

    T top() {
        while (!deleted.empty() && heap.top() == deleted.top()) {
            heap.pop();
            deleted.pop();
        }

        return heap.top();
    }

    size_t size() {
        return heap.size() - deleted.size();
    }

    bool empty() {
        return size() == 0;
    }

    void clear() {
        heap.clear();
        deleted.clear();
    }

    void swap(DeletableHeap &other) {
        heap.swap(other.heap);
        deleted.swap(other.deleted);
    }

    template<typename... _Args>
    void emplace(_Args &&...__args) {
        heap.emplace(std::forward<_Args>(__args)...);
    }
};
int n,m;
int head[N],nex[N*2],ver[N*2],tot=0;
void add(int x,int y){
    ver[++tot]=y,nex[tot]=head[x],head[x]=tot;
}
int a[N];
int mx[N],S,root,size[N],fa[N],M[N];
int siz[20][N],id[20][N],top[20][N],num[20];
int dep[N];
bool vis[N];
int rt[N],idx=0;
int dist[20][N];
struct tree{
    int l,r,mx,tag;
}tr[N*70];
void build(int &p,int l,int r,int d){
    if(!p) p=++idx;
    tr[p].tag=0;
    if(l==r){
        tr[p].mx=dist[d][l];
        return;
    }
    int const mid=(l+r)>>1;
    build(tr[p].l,l,mid,d);
    build(tr[p].r,mid+1,r,d);
    tr[p].mx=max(tr[tr[p].l].mx,tr[tr[p].r].mx);
}
void pushdown(int p){
    if(tr[p].tag){
        tr[tr[p].l].tag+=tr[p].tag;
        tr[tr[p].l].mx+=tr[p].tag;
        tr[tr[p].r].tag+=tr[p].tag;
        tr[tr[p].r].mx+=tr[p].tag;
        tr[p].tag=0;
    }
}
int ask(int p,int ls,int rs,int l,int r){
    if(!p) return -inf;
    if(l>=ls &&r<=rs) return tr[p].mx;
    pushdown(p);
    int const mid=(l+r)>>1;
    int ans=-inf;
    if(ls<=mid) ans=max(ans,ask(tr[p].l,ls,rs,l,mid));
    if(rs>mid) ans=max(ans,ask(tr[p].r,ls,rs,mid+1,r));
    return ans;
}
void change(int p,int ls,int rs,int l,int r,int v){
    if(l>=ls && r<=rs){
        tr[p].mx+=v;
        tr[p].tag+=v;
        return;
    }
    pushdown(p);
    int const mid=(l+r)>>1;
    if(ls<=mid) change(tr[p].l,ls,rs,l,mid,v);
    if(rs>mid) change(tr[p].r,ls,rs,mid+1,r,v);
    tr[p].mx=max(tr[tr[p].l].mx,tr[tr[p].r].mx);
}
void find(int x,int fat){
    size[x]=1;
    mx[x]=0;
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(y==fat || vis[y]) continue;
        find(y,x);
        size[x]+=size[y];
        mx[x]=max(mx[x],size[y]);
    }
    mx[x]=max(mx[x],S-size[x]);
    if(mx[x]<mx[root]) root=x;
}
void work(int x,int f,int tp,int op,int dis){
    siz[dep[op]][x]=1;
    id[dep[op]][x]=++num[dep[op]];
    top[dep[op]][x]=tp;
    dist[dep[op]][num[dep[op]]]=dis;
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(y==f || vis[y]) continue;
        work(y,x,tp,op,dis+a[y]);
        siz[dep[op]][x]+=siz[dep[op]][y];
    }
}
multiset<int, std::greater<>> s[N];
DeletableHeap<int, std::less<>> ans;
inline int getans(int x){
    auto it=s[x].begin();
    int cnt=a[x];
    cnt+=max(*it,0ll);
    it++;
    cnt+=max(*it,0ll);
    return cnt;
}
int mxd=0;
void divide(int x,int d){
    mxd=max(mxd,d);
    dep[x]=d;
    s[x].insert(-(1ll<<60));
    s[x].insert(-(1ll<<60));
    vis[x]=1;
    id[d][x]=++num[d];
    dist[d][num[d]]=0;
    siz[d][x]=1;
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(vis[y]) continue;
        work(y,x,y,x,a[y]);
        siz[dep[x]][x]+=siz[dep[x]][y];
    }
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(vis[y]) continue;
        S=size[y];
        root=0;
        find(y,x);
        fa[root]=x;
        divide(root,d+1);
    }
}
void divide1(int x,int d){
    vis[x]=1;
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(vis[y]) continue;
        int p=ask(rt[d],id[d][y],id[d][y]+siz[d][y]-1,1,num[d]);
        s[x].insert(p);
    }
    auto it=s[x].begin();
    int cnt=a[x];
    cnt+=max(*it,0ll);
    it++;
    cnt+=max(*it,0ll);
    ans.insert(cnt);
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(vis[y]) continue;
        S=size[y];
        root=0;
        find(y,x);
        fa[root]=x;
        divide1(root,d+1);
    }
}

void solve(int op,int x,int d){
    int ans1=getans(op);
    ans.erase(ans1);
    int tp=top[dep[op]][x];
    if(!tp){
        ans1+=d;
        ans.insert(ans1);
        return;
    }
    int lastans=ask(rt[dep[op]],id[dep[op]][tp],id[dep[op]][tp]+siz[dep[op]][tp]-1,1,num[dep[op]]);
    s[op].erase(s[op].find(lastans));
    change(rt[dep[op]],id[dep[op]][x],id[dep[op]][x]+siz[dep[op]][x]-1,1,num[dep[op]],d);
    int ans2=ask(rt[dep[op]],id[dep[op]][tp],id[dep[op]][tp]+siz[dep[op]][tp]-1,1,num[dep[op]]);
    s[op].insert(ans2);
    int ans3=getans(op);
    ans.insert(ans3);
}
void updata(int x,int d){
    int tmp=x;
    while(tmp){
        solve(tmp,x,d);
        tmp=fa[tmp];
    }
}
inline int read(){
	int x(0);bool f(0);char ch=getchar();
	for(;ch<'0'||ch>'9';ch=getchar()) f^=ch=='-';
	for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+(ch^48);
	return f?x=-x:x;
}
inline void write(int x){
	x<0?x=-x,putchar('-'):0;static short Sta[50],top(0);
	do{Sta[++top]=x%10;x/=10;}while(x);
	while(top) putchar(Sta[top--]|48);
	putchar('\n');
}
signed main(){
    freopen("taffy.in","r",stdin);
    freopen("taffy.out","w",stdout);
    ans.insert(-(1ll<<60));
    ans.insert(-(1ll<<60));
    n=read(),m=read();
    for(int i=1;i<n;i++){
        int x,y;
        x=read(),y=read();
        add(x,y),add(y,x);
    }
    for(int i=1;i<=n;i++) a[i]=read();
    root=0;
    S=n;
    mx[0]=n+1;
    find(1,0);
    divide(root,1);
    for(int i=1;i<=mxd;i++) build(rt[i],1,num[i],i);
    memset(vis,0,sizeof(vis));
    root=0;
    S=n;
    find(1,0);
    divide1(root,1);
    for(int i=1;i<=m;i++){
        int op,u,x;
        op=read();
        if(op==1){
            auto it=ans.top();
            printf("%lld\n",it);
        }
        else if(op==2){
            u=read();
            updata(u,inf);
            a[u]+=inf;
            auto it=ans.top();
            int cnt=it-inf;
            write(cnt);
            updata(u,-inf);       
            a[u]-=inf;
        }
        else{
            u=read(),x=read();
            int d=x-a[u];
            updata(u,d);
            a[u]=x;
        }
    }
}

[ZJOI2015] 幻想乡战略游戏 #

考虑 u,v 是相邻两个点,计算答案的话如果 v 优于 u 那么就向 v 走。发现,首先计算答案不好计算,其次可能会走 n 次。所以考虑给它建一棵点分树,每次查询都从整棵树的跟开始走,如果一个儿子优于自己,那么就走向那个儿子所指向的重心,那么最多走 logn 次,在考虑计算答案,对于每个分治节点维护 sumdu:u 为分治重心分治层的 d 之和; sumf2uu 为分治重心分治层中 dxdis(x,u) 之和; sumf2:u 为分治重心分治层中 dxdis(fax,u) 之和(faxx 在分治树上的父亲)。然后发现每次求答案就是跳分治树,每次计算上一层除去自己这一层的答案,复杂度也是 logn。总复杂度 nlog2n

code
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+10;
int head[N],nex[N*2],ver[N*2],edge[N*2],tot=0;
int n,Q;
void add(int x,int y,int w){
    ver[++tot]=y,nex[tot]=head[x],head[x]=tot,edge[tot]=w;
}
int g[N][40],dist[N],t,dep[N];
void dfs(int x,int fat){
    g[x][0]=fat;
    dep[x]=dep[fat]+1;
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(y==fat) continue;
        dist[y]=dist[x]+edge[i];
        dfs(y,x);
    }
}
void init(){
    t=log2(n)+1;
    for(int j=1;j<=t;j++)
        for(int i=1;i<=n;i++)
            g[i][j]=g[g[i][j-1]][j-1];
}
int ask_lca(int x,int y){
    if(dep[x]>dep[y]) swap(x,y);
    for(int i=t;i>=0;i--){
        if(dep[g[y][i]]>=dep[x]) y=g[y][i];
    }
    if(x==y) return x;
    for(int i=t;i>=0;i--){
        if(g[x][i]!=g[y][i]){
            x=g[x][i],y=g[y][i];
        }
    }
    return g[x][0];
}
int ask_dist(int x,int y){
    return dist[x]+dist[y]-2*dist[ask_lca(x,y)];
}
int root,mx[N],S,siz[N],fa[N];
bool vis[N];
unordered_map<int,int> rk[N];
void find(int x,int fat){
    mx[x]=1;
    siz[x]=1;
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(y==fat || vis[y]) continue;
        find(y,x);
        siz[x]+=siz[y];
        mx[x]=max(mx[x],siz[y]);
    }
    mx[x]=max(mx[x],S-siz[x]);
    if(mx[root]>mx[x]) root=x;
}
void solve(int x){
    vis[x]=1;
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(vis[y]) continue;
        root=0;
        S=siz[y];
        find(y,x);
        fa[root]=x;
        rk[x][y]=root;
        solve(root);
    }
}
int d[N],sumd[N],sumf1[N],sumf2[N];
void updata(int x,int v){
    int tmp=x;
    sumd[tmp]+=v;
    d[tmp]+=v;
    while(fa[tmp]){
        int dis=ask_dist(fa[tmp],x);
        sumf2[fa[tmp]]+=dis*v;
        sumf1[tmp]+=dis*v;
        sumd[fa[tmp]]+=v;
        tmp=fa[tmp];
    }
}
int calc(int x){
    int ans=sumf2[x];
    int tmp=x;
    while(fa[tmp]){
        int dis=ask_dist(fa[tmp],x);
        ans+=sumf2[fa[tmp]]-sumf1[tmp];
        ans+=dis*(sumd[fa[tmp]]-sumd[tmp]);
        tmp=fa[tmp];
    }
    return ans;
}
int query(int x){
    int ans=calc(x);
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        if(calc(y)<ans) return query(rk[x][y]);
    }
    return ans;
}

signed main(){
    scanf("%lld%lld",&n,&Q);
    for(int i=1;i<n;i++){
        int a,b,c;
        scanf("%lld%lld%lld",&a,&b,&c);
        add(a,b,c),add(b,a,c);
    }
    dfs(1,0);
    init();
    mx[0]=n+1;
    root=0;
    S=n;
    find(1,0);
    int gf=root;
    solve(root);
    for(int i=1;i<=Q;i++){
        int u,e;
        scanf("%lld%lld",&u,&e);
        updata(u,e);
        printf("%lld\n",query(gf));
    }
}

作者:bloss

出处:https://www.cnblogs.com/jinjiaqioi/p/17806380.html

版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。

posted @   _bloss  阅读(36)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 单线程的Redis速度为什么快?
· 展开说说关于C#中ORM框架的用法!
· Pantheons:用 TypeScript 打造主流大模型对话的一站式集成库
· SQL Server 2025 AI相关能力初探
· 为什么 退出登录 或 修改密码 无法使 token 失效
点击右上角即可分享
微信分享提示
more_horiz
keyboard_arrow_up dark_mode palette
选择主题
menu