splay和lct
1. 查询前驱,后继,排名
splay基本操作
#include <cstdio> const int N = 1e6+10, INF = 0x3f3f3f3f; int tot, rt; struct { int cnt,sz,fa,ch[2],v; } tr[N]; void pu(int x) { tr[x].sz=tr[tr[x].ch[0]].sz+tr[tr[x].ch[1]].sz+tr[x].cnt; } void rot(int x) { int y=tr[x].fa,z=tr[y].fa; int f=tr[y].ch[1]==x; tr[z].ch[tr[z].ch[1]==y]=x,tr[x].fa=z; tr[y].ch[f]=tr[x].ch[f^1],tr[tr[x].ch[f^1]].fa=y; tr[x].ch[f^1]=y,tr[y].fa=x,pu(y),pu(x); } //s=0时将x旋转到根, 否则将x旋转到s的儿子 void splay(int x, int s=0) { for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) { rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x); } if (!s) rt=x; } //若splay中存在权值x,那么把x旋转到根,否则把x的前驱或后继旋转到根 //(get函数只是为了简化求前驱和后继操作, 其他地方不要使用get) void get(int x) { int cur=rt; while (x!=tr[cur].v&&tr[cur].ch[x>tr[cur].v]) cur=tr[cur].ch[x>tr[cur].v]; splay(cur); } //插入权值x void insert(int x) { int cur=rt,p=0; while (cur&&x!=tr[cur].v) p=cur,cur=tr[cur].ch[x>tr[cur].v]; if (cur) ++tr[cur].cnt; else { cur=++tot; if (p) tr[p].ch[x>tr[p].v]=cur,tr[cur].fa=p; tr[cur].v=x,tr[cur].sz=tr[cur].cnt=1; } splay(cur); } //返回<=x的节点编号 int pre(int x) { get(x); if (tr[rt].v<=x) return rt; int cur=tr[rt].ch[0]; while (tr[cur].ch[1]) cur=tr[cur].ch[1]; splay(cur); return cur; } //返回>=x的节点编号 int nxt(int x) { get(x); if (tr[rt].v>=x) return rt; int cur=tr[rt].ch[1]; while (tr[cur].ch[0]) cur=tr[cur].ch[0]; splay(cur); return cur; } //若权值x存在删除x对应节点, 否则无影响 void erase(int x) { int s1=pre(x-1),s2=nxt(x+1); splay(s1),splay(s2,s1); int &cur=tr[s2].ch[0]; if (tr[cur].cnt>1) --tr[cur].cnt,splay(cur); else cur=0; } //返回权值x的排名(<x的数的个数+1) int rk(int x) { int t = pre(x-1); return tr[tr[t].ch[0]].sz+tr[t].cnt; } //返回排名为k-1的数 int kth(int x, int k) { int s=tr[tr[x].ch[0]].sz; if (k<=s) return kth(tr[x].ch[0],k); if (k>s+tr[x].cnt) return kth(tr[x].ch[1],k-s-tr[x].cnt); return splay(x),x; } int main() { int n; scanf("%d", &n); //初始化插入INF和-INF insert(INF),insert(-INF); while (n--) { int op, x; scanf("%d%d", &op, &x); //插入 if (op==1) insert(x); //删除 else if (op==2) erase(x); //查询排名(比x小的数的个数+1) else if (op==3) printf("%d\n",rk(x)); //查询排名为x的数 else if (op==4) printf("%d\n",tr[kth(rt,x+1)].v); //求x的前驱(小于x且最大的数) else if (op==5) printf("%d\n",tr[pre(x-1)].v); //求x的后继(大于x且最小的数) else printf("%d\n",tr[nxt(x+1)].v); } }
2. 区间翻转
维护splay中序遍历得到的序列, 查询下标相当于查询排名
#include <cstdio> #include <algorithm> using namespace std; const int N = 1e6+10; int rt, tot; struct { int sz,v,ch[2],fa,rev; void upd() {swap(ch[0],ch[1]);rev^=1;} } tr[N]; void pu(int o) { tr[o].sz=tr[tr[o].ch[0]].sz+tr[tr[o].ch[1]].sz+1; } void pd(int o) { if (tr[o].rev) { tr[tr[o].ch[0]].upd(); tr[tr[o].ch[1]].upd(); tr[o].rev=0; } } void rot(int x) { int y=tr[x].fa,z=tr[y].fa; int f=tr[y].ch[1]==x,w=tr[x].ch[f^1]; tr[z].ch[tr[z].ch[1]==y]=x; tr[x].ch[f^1]=y,tr[y].ch[f]=w; tr[w].fa=y; tr[y].fa=x,tr[x].fa=z; pu(y),pu(x); } void splay(int x, int s=0) { for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) { rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x); } if (!s) rt=x; } int find(int x, int k) { pd(x); int s=tr[tr[x].ch[0]].sz; if (k==s+1) return splay(x),x; if (k<=s) return find(tr[x].ch[0],k); return find(tr[x].ch[1],k-s-1); } void reverse(int x, int y) { int s1=find(rt,x), s2=find(rt,y+2); splay(s1),splay(s2,s1); tr[tr[s2].ch[0]].upd(); } void build(int f, int &o, int l, int r) { if (l>r) return; o = ++tot; int mid = (l+r)/2; tr[o].v = mid-1, tr[o].fa = f; build(o,tr[o].ch[0],l,mid-1); build(o,tr[o].ch[1],mid+1,r); pu(o); } int n,m; void dfs(int x) { if (!x) return; pd(x); dfs(tr[x].ch[0]); if (1<=tr[x].v&&tr[x].v<=n) printf("%d ",tr[x].v); dfs(tr[x].ch[1]); } int main() { scanf("%d%d", &n, &m); build(0,rt,1,n+2); while (m--) { int x, y; scanf("%d%d", &x, &y); reverse(x,y); } dfs(rt),puts(""); }
3. 区间加区间求和
用splay提取区间, 转化为子树加和子树求和, 打标记即可
#include <iostream> using namespace std; const int N = 1e5+10, INF = 0x3f3f3f3f; int tot, rt, a[N]; struct { int sz,fa,ch[2]; long long sum, tag, v; void add(long long x) {sum+=sz*x;tag+=x;v+=x;} } tr[N]; void pu(int x) { tr[x].sz = tr[tr[x].ch[0]].sz+tr[tr[x].ch[1]].sz+1; tr[x].sum = tr[tr[x].ch[0]].sum+tr[tr[x].ch[1]].sum+tr[x].v; } void pd(int x) { if (tr[x].tag) { tr[tr[x].ch[0]].add(tr[x].tag); tr[tr[x].ch[1]].add(tr[x].tag); tr[x].tag = 0; } } void rot(int x) { int y=tr[x].fa,z=tr[y].fa; int f=tr[y].ch[1]==x; tr[z].ch[tr[z].ch[1]==y]=x,tr[x].fa=z; tr[y].ch[f]=tr[x].ch[f^1],tr[tr[x].ch[f^1]].fa=y; tr[x].ch[f^1]=y,tr[y].fa=x,pu(y),pu(x); } void splay(int x, int s=0) { for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) { rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x); } if (!s) rt=x; } int find(int x, int k) { pd(x); int s = tr[tr[x].ch[0]].sz; if (k==s+1) return x; if (k<=s) return find(tr[x].ch[0],k); return find(tr[x].ch[1],k-s-1); } void build(int f, int &o, int l, int r) { if (l>r) return; o = ++tot; int mid = (l+r)/2; tr[o].v = a[mid], tr[o].fa = f; build(o,tr[o].ch[0],l,mid-1); build(o,tr[o].ch[1],mid+1,r); pu(o); } int split(int x, int y) { int s1 = find(rt,x), s2 = find(rt,y+2); splay(s1), splay(s2,s1); return tr[s2].ch[0]; } int main() { int n, m; scanf("%d%d", &n, &m); for (int i=2; i<=n+1; ++i) scanf("%d", &a[i]); build(0,rt,1,n+2); while (m--) { int op, x, y, z; scanf("%d%d%d", &op, &x, &y); if (op==1) { scanf("%d", &z); tr[split(x,y)].add(z); } else { printf("%lld\n", tr[split(x,y)].sum); } } }
用$lct$也可以做, 可以差分一下避免换根操作. 常数比splay略大
#include <bits/stdc++.h> using namespace std; const int N = 1e5+10; int a[N]; struct { int ch[2],fa,sz; long long sum, v, tag; void add(long long x) {sum+=sz*x,v+=x,tag+=x;} } tr[N]; void pd(int o) { if (tr[o].tag) { tr[tr[o].ch[0]].add(tr[o].tag); tr[tr[o].ch[1]].add(tr[o].tag); tr[o].tag = 0; } } void pu(int o) { tr[o].sz = tr[tr[o].ch[0]].sz+tr[tr[o].ch[1]].sz+1; tr[o].sum = tr[tr[o].ch[0]].sum+tr[tr[o].ch[1]].sum+tr[o].v; } int nroot(int x) { return tr[tr[x].fa].ch[0]==x||tr[tr[x].fa].ch[1]==x; } void rot(int x) { int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x,w=tr[x].ch[!k]; if (nroot(y)) tr[z].ch[tr[z].ch[1]==y]=x; tr[x].ch[!k]=y,tr[y].ch[k]=w; if (w) tr[w].fa=y; tr[y].fa=x,tr[x].fa=z; pu(x),pu(y); } void repush(int x) { if (nroot(x)) repush(tr[x].fa); pd(x); } void splay(int x) { repush(x); while (nroot(x)) { int y=tr[x].fa,z=tr[y].fa; if (nroot(y)) rot((tr[y].ch[0]==x)==(tr[z].ch[0]==y)?y:x); rot(x); } pu(x); } void access(int x) { for (int y=0,t=x; t; t=tr[y=t].fa) splay(t),tr[t].ch[1]=y,pu(t); splay(x); } int main() { int n, m; scanf("%d%d", &n, &m); for (int i=1; i<=n; ++i) { scanf("%d", &a[i]); tr[i].v = a[i]; if (i>1) tr[i].fa = i-1; } while (m--) { int op, x, y, z; scanf("%d%d%d", &op, &x, &y); if (op==1) { scanf("%d", &z); access(y); tr[y].add(z); if (x>1) access(x-1), tr[x-1].add(-z); } else { access(y); long long ans = tr[y].sum; if (x>1) access(x-1), ans -= tr[x-1].sum; printf("%lld\n", ans); } } }
4. 动态添边维护最小生成树
$n\le 10^3$的话, 可以先$prim$求出最小生成树, 添一条边$(u,v)$, 那么就暴力求出当前最小生成树中路径$(u,v)$上的最大边, 如果添的边更小就替换掉, 复杂度是$O(n(n+q))$
#include <bits/stdc++.h> using namespace std; const int N = 1e3+10, M = 1e5+10, INF = 0x3f3f3f3f; int n,m,q,ret,a[N][N],dis[N],pos[N],vis[N]; int ans[M],u,v,w; struct {int op,u,v,w;} e[M]; vector<int> g[N]; int dfs(int x, int f, int s) { if (x==s) return 1; for (int y:g[x]) if (y!=f) { if (dfs(y,x,s)) { if (a[x][y]>w) { w = a[x][y]; u = x, v = y; } ret=max(ret,a[y][x]); return 1; } } return 0; } int main() { scanf("%d%d%d", &n, &m, &q); memset(a,0x3f,sizeof a); for (int i=1; i<=m; ++i) { int u, v, w; scanf("%d%d%d", &u, &v, &w); a[u][v] = a[v][u] = w; } for (int i=1; i<=q; ++i) { scanf("%d%d%d", &e[i].op, &e[i].u, &e[i].v); if (e[i].op==2) { e[i].w = a[e[i].u][e[i].v]; a[e[i].u][e[i].v] = a[e[i].v][e[i].u] = INF; } } memset(dis,0x3f,sizeof dis); dis[1] = 0; for (int i=1; i<=n; ++i) { int mi = INF, x, y; for (int j=1; j<=n; ++j) { if (!vis[j]&&dis[j]<mi) { mi = dis[j]; x = j, y = pos[j]; } } vis[x] = 1; if (x!=1) { g[x].push_back(y); g[y].push_back(x); } for (int j=1; j<=n; ++j) { if (dis[j]>a[x][j]) dis[j] = a[x][j], pos[j] = x; } } for (int i=q; i; --i) { w = 0; if (e[i].op==1) { dfs(e[i].u,0,e[i].v); ans[i] = w; } else { dfs(e[i].u,0,e[i].v); a[e[i].u][e[i].v] = a[e[i].v][e[i].u] = e[i].w; auto del = [&](int u, int v) { for (int i=0; i<g[u].size(); ++i) { if (g[u][i]==v) { swap(g[u][i],g[u].back()); g[u].pop_back(); return; } } }; if (e[i].w<w) { del(u,v); del(v,u); g[e[i].u].push_back(e[i].v); g[e[i].v].push_back(e[i].u); } } } for (int i=1; i<=q; ++i) if (e[i].op==1) printf("%d\n", ans[i]); }
用lct的话等价于要找边权最大值以及两端点, $lct$维护边权可以把每条边新建一个点, 转化为维护点权
复杂度是$O((m+q)\log{m})$
#include <bits/stdc++.h> using namespace std; const int N = 2e5+10; int n,m,q,fa[N],ans[N],vis[N],val[N]; struct edge {int u,v,w;} e[N]; struct Q {int op,u,v,id;} f[N]; map<pair<int,int>,int> mp; int Find(int x) {return fa[x]?fa[x]=Find(fa[x]):x;} struct { int ch[2],fa,tag,v; void rev() {tag^=1;swap(ch[0],ch[1]);} } tr[N]; void pd(int o) { if (tr[o].tag) { tr[tr[o].ch[0]].rev(); tr[tr[o].ch[1]].rev(); tr[o].tag = 0; } } void pu(int o) { tr[o].v = max({val[o], tr[tr[o].ch[0]].v, tr[tr[o].ch[1]].v}); } int nroot(int x) { return tr[tr[x].fa].ch[0]==x||tr[tr[x].fa].ch[1]==x; } void rot(int x) { int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x,w=tr[x].ch[!k]; if (nroot(y)) tr[z].ch[tr[z].ch[1]==y]=x; tr[x].ch[!k]=y,tr[y].ch[k]=w; if (w) tr[w].fa=y; tr[y].fa=x,tr[x].fa=z; pu(x),pu(y); } void repush(int x) { if (nroot(x)) repush(tr[x].fa); pd(x); } void splay(int x) { repush(x); while (nroot(x)) { int y=tr[x].fa,z=tr[y].fa; if (nroot(y)) rot((tr[y].ch[0]==x)==(tr[z].ch[0]==y)?y:x); rot(x); } pu(x); } void access(int x) { for (int y=0; x; x=tr[y=x].fa) splay(x),tr[x].ch[1]=y,pu(x); } void makeroot(int x) { access(x),splay(x); tr[x].rev(); } int findroot(int x) { access(x),splay(x); while (tr[x].ch[0]) pd(x),x=tr[x].ch[0]; return splay(x), x; } void split(int x, int y) { makeroot(x),access(y),splay(y); } void link(int x, int y) { makeroot(x); tr[x].fa=y; } void cut(int x, int y) { split(x,y); tr[y].ch[0]=tr[x].fa=0; pu(y); } int main() { scanf("%d%d%d", &n, &m, &q); for (int i=1; i<=m; ++i) { scanf("%d%d%d", &e[i].u, &e[i].v, &e[i].w); if (e[i].u>e[i].v) swap(e[i].u, e[i].v); } for (int i=1; i<=q; ++i) { scanf("%d%d%d", &f[i].op, &f[i].u, &f[i].v); if (f[i].u>f[i].v) swap(f[i].u,f[i].v); } sort(e+1,e+1+m,[](edge a,edge b){return a.w<b.w;}); for (int i=1; i<=m; ++i) { val[i+n] = mp[{e[i].u,e[i].v}] = i; } for (int i=1; i<=q; ++i) { if (f[i].op==2) { f[i].id = mp[{f[i].u,f[i].v}]; vis[f[i].id] = 1; } } for (int i=1; i<=m; ++i) if (!vis[i]) { int u = Find(e[i].u), v = Find(e[i].v); if (u!=v) { link(e[i].u,i+n); link(e[i].v,i+n); fa[u] = v; } } for (int i=q; i; --i) { if (f[i].op==1) { split(f[i].u,f[i].v); ans[i] = e[tr[f[i].v].v].w; } else { split(f[i].u,f[i].v); int x = tr[f[i].v].v, y = mp[{f[i].u,f[i].v}]; if (x>y) { cut(e[x].u,x+n); cut(e[x].v,x+n); link(e[y].u,y+n); link(e[y].v,y+n); } } } for (int i=1; i<=q; ++i) if (f[i].op==1) printf("%d\n", ans[i]); }