splay和lct

1. 查询前驱,后继,排名

splay基本操作

#include <cstdio>
const int N = 1e6+10, INF = 0x3f3f3f3f;
int tot, rt;
struct {
    int cnt,sz,fa,ch[2],v;
} tr[N];
void pu(int x) {
    tr[x].sz=tr[tr[x].ch[0]].sz+tr[tr[x].ch[1]].sz+tr[x].cnt;
}
void rot(int x) {
    int y=tr[x].fa,z=tr[y].fa;
    int f=tr[y].ch[1]==x;
    tr[z].ch[tr[z].ch[1]==y]=x,tr[x].fa=z;
    tr[y].ch[f]=tr[x].ch[f^1],tr[tr[x].ch[f^1]].fa=y;
    tr[x].ch[f^1]=y,tr[y].fa=x,pu(y),pu(x);
}
//s=0时将x旋转到根, 否则将x旋转到s的儿子
void splay(int x, int s=0) {
    for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) {
        rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x);
    }
    if (!s) rt=x;
}
//若splay中存在权值x,那么把x旋转到根,否则把x的前驱或后继旋转到根
//(get函数只是为了简化求前驱和后继操作, 其他地方不要使用get)
void get(int x) {
    int cur=rt;
    while (x!=tr[cur].v&&tr[cur].ch[x>tr[cur].v]) cur=tr[cur].ch[x>tr[cur].v];
    splay(cur);
}
//插入权值x
void insert(int x) {
    int cur=rt,p=0;
    while (cur&&x!=tr[cur].v) p=cur,cur=tr[cur].ch[x>tr[cur].v];
    if (cur) ++tr[cur].cnt;
    else {
        cur=++tot;
        if (p) tr[p].ch[x>tr[p].v]=cur,tr[cur].fa=p;
        tr[cur].v=x,tr[cur].sz=tr[cur].cnt=1;
    }
    splay(cur);
}
//返回<=x的节点编号
int pre(int x) {
    get(x);
    if (tr[rt].v<=x) return rt;
    int cur=tr[rt].ch[0];
    while (tr[cur].ch[1]) cur=tr[cur].ch[1];
    splay(cur);
    return cur;
}
//返回>=x的节点编号
int nxt(int x) {
    get(x);
    if (tr[rt].v>=x) return rt;
    int cur=tr[rt].ch[1];
    while (tr[cur].ch[0]) cur=tr[cur].ch[0];
    splay(cur);
    return cur;
}
//若权值x存在删除x对应节点, 否则无影响
void erase(int x) {
    int s1=pre(x-1),s2=nxt(x+1);
    splay(s1),splay(s2,s1);
    int &cur=tr[s2].ch[0];
    if (tr[cur].cnt>1) --tr[cur].cnt,splay(cur);
    else cur=0;
}
//返回权值x的排名(<x的数的个数+1)
int rk(int x) {
    int t = pre(x-1);
    return tr[tr[t].ch[0]].sz+tr[t].cnt;
}
//返回排名为k-1的数
int kth(int x, int k) {
    int s=tr[tr[x].ch[0]].sz;
    if (k<=s) return kth(tr[x].ch[0],k);
    if (k>s+tr[x].cnt) return kth(tr[x].ch[1],k-s-tr[x].cnt);
    return splay(x),x;
}
int main() {
    int n;
    scanf("%d", &n);
    //初始化插入INF和-INF
    insert(INF),insert(-INF);
    while (n--) {
        int op, x;
        scanf("%d%d", &op, &x);
        //插入
        if (op==1) insert(x);
        //删除
        else if (op==2) erase(x);
        //查询排名(比x小的数的个数+1)
        else if (op==3) printf("%d\n",rk(x));
        //查询排名为x的数
        else if (op==4) printf("%d\n",tr[kth(rt,x+1)].v);
        //求x的前驱(小于x且最大的数)
        else if (op==5) printf("%d\n",tr[pre(x-1)].v);
        //求x的后继(大于x且最小的数)
        else printf("%d\n",tr[nxt(x+1)].v);
    }
}
P3369

2. 区间翻转

维护splay中序遍历得到的序列, 查询下标相当于查询排名

#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 1e6+10;
int rt, tot;
struct {
    int sz,v,ch[2],fa,rev;
    void upd() {swap(ch[0],ch[1]);rev^=1;}
} tr[N];
void pu(int o) {
    tr[o].sz=tr[tr[o].ch[0]].sz+tr[tr[o].ch[1]].sz+1;
}
void pd(int o) {
    if (tr[o].rev) {
        tr[tr[o].ch[0]].upd();
        tr[tr[o].ch[1]].upd();
        tr[o].rev=0;
    }
}
void rot(int x) {
    int y=tr[x].fa,z=tr[y].fa;
    int f=tr[y].ch[1]==x,w=tr[x].ch[f^1];
    tr[z].ch[tr[z].ch[1]==y]=x;
    tr[x].ch[f^1]=y,tr[y].ch[f]=w;
    tr[w].fa=y;
    tr[y].fa=x,tr[x].fa=z;
    pu(y),pu(x);
}
void splay(int x, int s=0) {
    for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) {
        rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x);
    }
    if (!s) rt=x;
}
int find(int x, int k) {
    pd(x); int s=tr[tr[x].ch[0]].sz;
    if (k==s+1) return splay(x),x;
    if (k<=s) return find(tr[x].ch[0],k);
    return find(tr[x].ch[1],k-s-1);
}
void reverse(int x, int y) {
    int s1=find(rt,x), s2=find(rt,y+2);
    splay(s1),splay(s2,s1);
    tr[tr[s2].ch[0]].upd();
}
void build(int f, int &o, int l, int r) {
    if (l>r) return;
    o = ++tot; int mid = (l+r)/2;
    tr[o].v = mid-1, tr[o].fa = f;
    build(o,tr[o].ch[0],l,mid-1);
    build(o,tr[o].ch[1],mid+1,r);
    pu(o);
}
int n,m;
void dfs(int x) {
    if (!x) return;
    pd(x);
    dfs(tr[x].ch[0]);
    if (1<=tr[x].v&&tr[x].v<=n) printf("%d ",tr[x].v);
    dfs(tr[x].ch[1]);
}
int main() {
    scanf("%d%d", &n, &m);
    build(0,rt,1,n+2);
    while (m--) {
        int x, y;
        scanf("%d%d", &x, &y);
        reverse(x,y);
    }
    dfs(rt),puts("");
}
P3391

3. 区间加区间求和

用splay提取区间, 转化为子树加和子树求和, 打标记即可

#include <iostream>
using namespace std;
const int N = 1e5+10, INF = 0x3f3f3f3f;
int tot, rt, a[N];
struct {
    int sz,fa,ch[2];
    long long sum, tag, v;
    void add(long long x) {sum+=sz*x;tag+=x;v+=x;}
} tr[N];
void pu(int x) {
    tr[x].sz = tr[tr[x].ch[0]].sz+tr[tr[x].ch[1]].sz+1;
    tr[x].sum = tr[tr[x].ch[0]].sum+tr[tr[x].ch[1]].sum+tr[x].v;
}
void pd(int x) {
    if (tr[x].tag) {
        tr[tr[x].ch[0]].add(tr[x].tag);
        tr[tr[x].ch[1]].add(tr[x].tag);
        tr[x].tag = 0;
    }
}
void rot(int x) {
    int y=tr[x].fa,z=tr[y].fa;
    int f=tr[y].ch[1]==x;
    tr[z].ch[tr[z].ch[1]==y]=x,tr[x].fa=z;
    tr[y].ch[f]=tr[x].ch[f^1],tr[tr[x].ch[f^1]].fa=y;
    tr[x].ch[f^1]=y,tr[y].fa=x,pu(y),pu(x);
}
void splay(int x, int s=0) {
    for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) {
        rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x);
    }
    if (!s) rt=x;
}
int find(int x, int k) {
    pd(x); int s = tr[tr[x].ch[0]].sz;
    if (k==s+1) return x;
    if (k<=s) return find(tr[x].ch[0],k);
    return find(tr[x].ch[1],k-s-1);
}
void build(int f, int &o, int l, int r) {
    if (l>r) return;
    o = ++tot; int mid = (l+r)/2;
    tr[o].v = a[mid], tr[o].fa = f;
    build(o,tr[o].ch[0],l,mid-1);
    build(o,tr[o].ch[1],mid+1,r);
    pu(o);
}
int split(int x, int y) {
    int s1 = find(rt,x), s2 = find(rt,y+2);
    splay(s1), splay(s2,s1);
    return tr[s2].ch[0];
}
int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    for (int i=2; i<=n+1; ++i) scanf("%d", &a[i]);
    build(0,rt,1,n+2);
    while (m--) {
        int op, x, y, z;
        scanf("%d%d%d", &op, &x, &y);
        if (op==1) {
            scanf("%d", &z);
            tr[split(x,y)].add(z);
        }
        else {
            printf("%lld\n", tr[split(x,y)].sum);
        }
    }
}
P3372

用$lct$也可以做, 可以差分一下避免换根操作. 常数比splay略大

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5+10;
int a[N];
struct {
    int ch[2],fa,sz;
    long long sum, v, tag;
    void add(long long x) {sum+=sz*x,v+=x,tag+=x;}
} tr[N];
void pd(int o) {
    if (tr[o].tag) {
        tr[tr[o].ch[0]].add(tr[o].tag);
        tr[tr[o].ch[1]].add(tr[o].tag);
        tr[o].tag = 0;
    }
}
void pu(int o) {
    tr[o].sz = tr[tr[o].ch[0]].sz+tr[tr[o].ch[1]].sz+1;
    tr[o].sum = tr[tr[o].ch[0]].sum+tr[tr[o].ch[1]].sum+tr[o].v;
}
int nroot(int x) {
    return tr[tr[x].fa].ch[0]==x||tr[tr[x].fa].ch[1]==x;
}
void rot(int x) {
    int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x,w=tr[x].ch[!k];
    if (nroot(y)) tr[z].ch[tr[z].ch[1]==y]=x;
    tr[x].ch[!k]=y,tr[y].ch[k]=w;
    if (w) tr[w].fa=y;
    tr[y].fa=x,tr[x].fa=z;
    pu(x),pu(y);
}
void repush(int x) {
    if (nroot(x)) repush(tr[x].fa);
    pd(x);
}
void splay(int x) {
    repush(x);
    while (nroot(x)) {
        int y=tr[x].fa,z=tr[y].fa;
        if (nroot(y)) rot((tr[y].ch[0]==x)==(tr[z].ch[0]==y)?y:x);
        rot(x);
    }
    pu(x);
}
void access(int x) {
    for (int y=0,t=x; t; t=tr[y=t].fa) splay(t),tr[t].ch[1]=y,pu(t);
    splay(x);
}
int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    for (int i=1; i<=n; ++i) { 
        scanf("%d", &a[i]);
        tr[i].v = a[i];
        if (i>1) tr[i].fa = i-1;
    }
    while (m--) {
        int op, x, y, z;
        scanf("%d%d%d", &op, &x, &y);
        if (op==1) {
            scanf("%d", &z);
            access(y);
            tr[y].add(z);
            if (x>1) access(x-1), tr[x-1].add(-z);
        }
        else {
            access(y);
            long long ans = tr[y].sum;
            if (x>1) access(x-1), ans -= tr[x-1].sum;
            printf("%lld\n", ans);
        }
    }
}
View Code

4. 动态添边维护最小生成树

$n\le 10^3$的话, 可以先$prim$求出最小生成树, 添一条边$(u,v)$, 那么就暴力求出当前最小生成树中路径$(u,v)$上的最大边, 如果添的边更小就替换掉, 复杂度是$O(n(n+q))$

#include <bits/stdc++.h>
using namespace std;
const int N = 1e3+10, M = 1e5+10, INF = 0x3f3f3f3f;
int n,m,q,ret,a[N][N],dis[N],pos[N],vis[N];
int ans[M],u,v,w;
struct {int op,u,v,w;} e[M];
vector<int> g[N];
int dfs(int x, int f, int s) {
    if (x==s) return 1;
    for (int y:g[x]) if (y!=f) { 
        if (dfs(y,x,s)) { 
            if (a[x][y]>w) {
                w = a[x][y];
                u = x, v = y;
            }
            ret=max(ret,a[y][x]);
            return 1;
        }
    }
    return 0;
}
int main() {
    scanf("%d%d%d", &n, &m, &q);
    memset(a,0x3f,sizeof a);
    for (int i=1; i<=m; ++i) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        a[u][v] = a[v][u] = w;
    }
    for (int i=1; i<=q; ++i) {
        scanf("%d%d%d", &e[i].op, &e[i].u, &e[i].v);
        if (e[i].op==2) { 
            e[i].w = a[e[i].u][e[i].v];
            a[e[i].u][e[i].v] = a[e[i].v][e[i].u] = INF;
        }
    }
    memset(dis,0x3f,sizeof dis);
    dis[1] = 0;
    for (int i=1; i<=n; ++i) {
        int mi = INF, x, y;
        for (int j=1; j<=n; ++j) {
            if (!vis[j]&&dis[j]<mi) {
                mi = dis[j];
                x = j, y = pos[j];
            }
        }
        vis[x] = 1;
        if (x!=1) { 
            g[x].push_back(y);
            g[y].push_back(x);
        }
        for (int j=1; j<=n; ++j) {
            if (dis[j]>a[x][j]) dis[j] = a[x][j], pos[j] = x;
        }
    }
    for (int i=q; i; --i) {
        w = 0;
        if (e[i].op==1) { 
            dfs(e[i].u,0,e[i].v);
            ans[i] = w;
        }
        else {
            dfs(e[i].u,0,e[i].v);
            a[e[i].u][e[i].v] = a[e[i].v][e[i].u] = e[i].w;
            auto del = [&](int u, int v) {
                for (int i=0; i<g[u].size(); ++i) {
                    if (g[u][i]==v) {
                        swap(g[u][i],g[u].back());
                        g[u].pop_back();
                        return;
                    }
                }
            };
            if (e[i].w<w) {
                del(u,v);
                del(v,u);
                g[e[i].u].push_back(e[i].v);
                g[e[i].v].push_back(e[i].u);
            }
        }
    }
    for (int i=1; i<=q; ++i) if (e[i].op==1) printf("%d\n", ans[i]);
}
View Code

用lct的话等价于要找边权最大值以及两端点, $lct$维护边权可以把每条边新建一个点, 转化为维护点权

复杂度是$O((m+q)\log{m})$ 

#include <bits/stdc++.h>
using namespace std;
const int N = 2e5+10;
int n,m,q,fa[N],ans[N],vis[N],val[N];
struct edge {int u,v,w;} e[N];
struct Q {int op,u,v,id;} f[N];
map<pair<int,int>,int> mp;
int Find(int x) {return fa[x]?fa[x]=Find(fa[x]):x;}
struct {
    int ch[2],fa,tag,v;
    void rev() {tag^=1;swap(ch[0],ch[1]);}
} tr[N];
void pd(int o) {
    if (tr[o].tag) {
        tr[tr[o].ch[0]].rev();
        tr[tr[o].ch[1]].rev();
        tr[o].tag = 0;
    }
}
void pu(int o) {
    tr[o].v = max({val[o], tr[tr[o].ch[0]].v, tr[tr[o].ch[1]].v});
}
int nroot(int x) {
    return tr[tr[x].fa].ch[0]==x||tr[tr[x].fa].ch[1]==x;
}
void rot(int x) {
    int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x,w=tr[x].ch[!k];
    if (nroot(y)) tr[z].ch[tr[z].ch[1]==y]=x;
    tr[x].ch[!k]=y,tr[y].ch[k]=w;
    if (w) tr[w].fa=y;
    tr[y].fa=x,tr[x].fa=z;
    pu(x),pu(y);
}
void repush(int x) {
    if (nroot(x)) repush(tr[x].fa);
    pd(x);
}
void splay(int x) {
    repush(x);
    while (nroot(x)) {
        int y=tr[x].fa,z=tr[y].fa;
        if (nroot(y)) rot((tr[y].ch[0]==x)==(tr[z].ch[0]==y)?y:x);
        rot(x);
    }
    pu(x);
}
void access(int x) {
    for (int y=0; x; x=tr[y=x].fa) splay(x),tr[x].ch[1]=y,pu(x);
}
void makeroot(int x) {
    access(x),splay(x);
    tr[x].rev();
}
int findroot(int x) {
    access(x),splay(x);
    while (tr[x].ch[0]) pd(x),x=tr[x].ch[0];
    return splay(x), x;
}
void split(int x, int y) {
    makeroot(x),access(y),splay(y);
}
void link(int x, int y) {
    makeroot(x);
    tr[x].fa=y;
}
void cut(int x, int y) {
    split(x,y);
    tr[y].ch[0]=tr[x].fa=0;
    pu(y);
}
int main() {
    scanf("%d%d%d", &n, &m, &q);
    for (int i=1; i<=m; ++i) { 
        scanf("%d%d%d", &e[i].u, &e[i].v, &e[i].w);
        if (e[i].u>e[i].v) swap(e[i].u, e[i].v);
    }
    for (int i=1; i<=q; ++i) {
        scanf("%d%d%d", &f[i].op, &f[i].u, &f[i].v);
        if (f[i].u>f[i].v) swap(f[i].u,f[i].v);
    }
    sort(e+1,e+1+m,[](edge a,edge b){return a.w<b.w;});
    for (int i=1; i<=m; ++i) { 
        val[i+n] = mp[{e[i].u,e[i].v}] = i;
    }
    for (int i=1; i<=q; ++i) {
        if (f[i].op==2) { 
            f[i].id = mp[{f[i].u,f[i].v}];
            vis[f[i].id] = 1;
        }
    }
    for (int i=1; i<=m; ++i) if (!vis[i]) {
        int u = Find(e[i].u), v = Find(e[i].v);
        if (u!=v) {
            link(e[i].u,i+n);
            link(e[i].v,i+n);
            fa[u] = v;
        }
    }
    for (int i=q; i; --i) {
        if (f[i].op==1) {
            split(f[i].u,f[i].v);
            ans[i] = e[tr[f[i].v].v].w;
        }
        else { 
            split(f[i].u,f[i].v);
            int x = tr[f[i].v].v, y = mp[{f[i].u,f[i].v}];
            if (x>y) {
                cut(e[x].u,x+n);
                cut(e[x].v,x+n);
                link(e[y].u,y+n);
                link(e[y].v,y+n);
            }
        }
    }
    for (int i=1; i<=q; ++i) if (f[i].op==1) printf("%d\n", ans[i]);
}
View Code

 

posted @ 2020-10-10 15:16  dz8gk0j  阅读(161)  评论(0编辑  收藏  举报