OI常见模板汇总

单源最短路:

堆优化的Dijkstra:

#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for (register int i = (a); i <= (b); i++)

// inf uses the byte pattern 0x3f so memset can fill it; inf + inf still fits in int.
const int inf = 0x3f3f3f3f;
const int maxn = 1e5 + 5;

// n vertices, m edges, source S; num_edge = number of edges stored (1-based).
int n, m, S;
int num_edge = 0;
int dis[maxn];   // tentative distance from S
int vis[maxn];   // 1 once the vertex has been settled
int head[maxn];  // first edge out of each vertex, -1 if none

// Max-heap keyed on (-distance, vertex), i.e. a min-heap on distance.
priority_queue< pair<int, int> > q;

// Forward-star edge: target vertex, next edge from the same source, weight.
struct node {
    int to;
    int nxt;
    int dis;
};
node edge[maxn << 1];

void origin() {
    memset(dis, inf, sizeof(dis));
    memset(head, -1, sizeof(head));
} 

// Reads the next decimal integer (optionally preceded by '-') from stdin,
// skipping any other characters in front of it. Returns 0 at end of input:
// `ch` is an int so EOF is distinguishable from a real character, and the
// skip loop stops there instead of spinning forever.
int read() {
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    bool negative = false;
    if (ch == '-') {
        negative = true;
        ch = getchar();
    }
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -x : x;
}

void addedge(int from, int to, int dis) {
    edge[++num_edge].nxt = head[from];
    edge[num_edge].to = to;
    edge[num_edge].dis = dis;
    head[from] = num_edge;
}

// Prints x in decimal to stdout (no separator). Negation is done in unsigned
// arithmetic so INT_MIN prints correctly instead of hitting signed overflow (UB).
void write(int x) {
    unsigned int mag = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        mag = 0u - mag;
    }
    char digits[12];
    int n_digits = 0;
    do {
        digits[n_digits++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag);
    while (n_digits > 0) putchar(digits[--n_digits]);
}

// Heap-optimised Dijkstra from S; fills dis[] (unreached vertices keep inf).
// Requires non-negative edge weights.
void dijkstra(int S) {
    dis[S] = 0;
    q.push(make_pair(0, S));
    while (!q.empty()) {
        const int u = q.top().second;
        q.pop();
        if (vis[u]) continue;  // stale heap entry: u already settled
        vis[u] = 1;
        for (int i = head[u]; i != -1; i = edge[i].nxt) {
            const int v = edge[i].to;
            const int cand = dis[u] + edge[i].dis;
            if (cand < dis[v]) {
                dis[v] = cand;
                q.push(make_pair(-cand, v));  // negated key => smallest distance on top
            }
        }
    }
}

int main() {
    origin();
    n = read(), m = read(), S = read();
    rep(i, 1, m) {
        int u, v, w;
        u = read(), v = read(), w = read();
        addedge(u, v, w);
    }
    dijkstra(S);
    rep(i, 1, n) {
        write(dis[i]);
        printf(" ");
    }
}
View Code

 

队列优化的Bellman-Ford(SPFA):

#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for (register int i = (a); i <= (b); i++)

// inf uses the byte pattern 0x3f so memset can fill it; inf + inf still fits in int.
const int inf = 0x3f3f3f3f;
const int maxn = 1e5 + 5;

// n vertices, m edges, source S; num_edge = number of edges stored (1-based).
int n, m, S;
int num_edge = 0;
int dis[maxn];   // current best distance from S
int vis[maxn];   // 1 while the vertex sits in the queue
int head[maxn];  // first edge out of each vertex, -1 if none

// FIFO of vertices whose distance improved and must be relaxed again.
queue<int> q;

// Forward-star edge: target vertex, next edge from the same source, weight.
struct node {
    int to;
    int nxt;
    int dis;
};
node edge[maxn << 1];

void origin() {
    memset(dis, inf, sizeof(dis));
    memset(head, -1, sizeof(head));
}

// Reads the next decimal integer (optionally preceded by '-') from stdin,
// skipping any other characters in front of it. Returns 0 at end of input:
// `ch` is an int so EOF is distinguishable from a real character, and the
// skip loop stops there instead of spinning forever.
int read() {
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    bool negative = false;
    if (ch == '-') {
        negative = true;
        ch = getchar();
    }
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -x : x;
}

void addedge(int from, int to, int dis) {
    edge[++num_edge].nxt = head[from];
    edge[num_edge].to = to;
    edge[num_edge].dis = dis;
    head[from] = num_edge;
}

// Queue-based Bellman-Ford (SPFA) from S; fills dis[]. Handles negative
// weights, but does not detect negative cycles (it would not terminate).
void Bellman_Ford(int S) {
    dis[S] = 0;
    vis[S] = 1;
    q.push(S);
    while (!q.empty()) {
        const int u = q.front();
        q.pop();
        vis[u] = 0;
        for (int i = head[u]; i != -1; i = edge[i].nxt) {
            const int v = edge[i].to;
            const int cand = dis[u] + edge[i].dis;
            if (cand >= dis[v]) continue;
            dis[v] = cand;
            if (vis[v]) continue;  // already queued; it will see the new value
            vis[v] = 1;
            q.push(v);
        }
    }
}

// Prints x in decimal to stdout (no separator). Negation is done in unsigned
// arithmetic so INT_MIN prints correctly instead of hitting signed overflow (UB).
void write(int x) {
    unsigned int mag = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        mag = 0u - mag;
    }
    char digits[12];
    int n_digits = 0;
    do {
        digits[n_digits++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag);
    while (n_digits > 0) putchar(digits[--n_digits]);
}

int main() {
    origin();
    n = read(), m = read(), S = read();
    rep(i, 1, m) {
        int u, v, w;
        u = read(), v = read(), w = read();
        addedge(u, v, w);
    }
    Bellman_Ford(S);
    rep(i, 1, n) {
        write(dis[i]);
        printf(" ");
    }
    return 0;
}
View Code

 

任意两点间最短路径(Floyd):

#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for (register int i = (a); i <= (b); i++)

// inf uses the byte pattern 0x3f; inf + inf still fits in int, so relaxing
// through an unreachable middle vertex cannot overflow.
const int inf = 0x3f3f3f3f;
const int maxn = 1e3 + 5;

int n, m;
int dis[maxn][maxn];  // dis[i][j]: current shortest i -> j distance

// Returns the smaller of a and b.
int MIN(int a, int b) {
    if (b <= a) return b;
    return a;
}

void origin() {memset(dis, inf, sizeof(dis));}

// Reads the next decimal integer (optionally preceded by '-') from stdin,
// skipping any other characters in front of it. Returns 0 at end of input:
// `ch` is an int so EOF is distinguishable from a real character, and the
// skip loop stops there instead of spinning forever.
int read() {
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    bool negative = false;
    if (ch == '-') {
        negative = true;
        ch = getchar();
    }
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -x : x;
}

// Prints x in decimal to stdout (no separator). Negation is done in unsigned
// arithmetic so INT_MIN prints correctly instead of hitting signed overflow (UB).
void write(int x) {
    unsigned int mag = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        mag = 0u - mag;
    }
    char digits[12];
    int n_digits = 0;
    do {
        digits[n_digits++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag);
    while (n_digits > 0) putchar(digits[--n_digits]);
}

int main() {
    origin();
    n = read(), m = read();
     rep(i, 1, m) {
         int u, v, w;
         u = read(), v = read(), w = read();
         dis[u][v] = MIN(dis[u][v], w);
    }
    rep(k, 1, n)
        rep(i, 1, n)
            rep(j, 1, n)
                 dis[i][j] = MIN(dis[i][j], dis[i][k] + dis[k][j]);
    rep(i, 1, n)
        rep(j, 1, n) {
            write(dis[i][j]);
            printf(" ");
        }      
    return 0;
}
View Code

 

最小生成树

Kruskal:

#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for (register int i = (a); i <= (b); i++)

const int maxn = 5e3 + 5;  // vertex capacity
const int maxm = 2e5 + 5;  // edge capacity

int n, m;
int ans = 0;       // total weight of the chosen edges
int fa[maxn];      // union-find parent

// Undirected edge stored as an explicit (from, to, weight) record.
struct rec {
    int to;
    int dis;
    int from;
};
rec edge[maxm];

// Orders edges by ascending weight, the order Kruskal's greedy scan needs.
bool operator <(const rec &a, const rec &b) {
    return b.dis > a.dis;
}

// Reads the next decimal integer (optionally preceded by '-') from stdin,
// skipping any other characters in front of it. Returns 0 at end of input:
// `ch` is an int so EOF is distinguishable from a real character, and the
// skip loop stops there instead of spinning forever.
int read() {
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    bool negative = false;
    if (ch == '-') {
        negative = true;
        ch = getchar();
    }
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -x : x;
}

// Union-find root lookup with path compression: every node on the path from
// x ends up pointing directly at the root.
int find(int x) {
    if (fa[x] != x) fa[x] = find(fa[x]);
    return fa[x];
}

// Prints x in decimal to stdout (no separator). Negation is done in unsigned
// arithmetic so INT_MIN prints correctly instead of hitting signed overflow (UB).
void write(int x) {
    unsigned int mag = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        mag = 0u - mag;
    }
    char digits[12];
    int n_digits = 0;
    do {
        digits[n_digits++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag);
    while (n_digits > 0) putchar(digits[--n_digits]);
}

int main() {
    n = read(), m = read();
    rep(i, 1, m) edge[i].from = read(), edge[i].to = read(), edge[i].dis = read();
    sort(edge + 1, edge + m + 1);
    rep(i, 1, n) fa[i] = i;
    rep(i, 1, m) {
        int x = find(edge[i].from), y = find(edge[i].to);
        if (x == y) continue;
        fa[x] = y;
        ans += edge[i].dis;
    }
    write(ans);
    return 0;
}
View Code

 

堆优化的Prim:

#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for (register int i = (a); i <= (b); i++)

const int inf = 0x3f3f3f3f;  // byte pattern 0x3f, fillable by memset
const int maxn = 5e3 + 5;    // vertex capacity
const int maxm = 4e5 + 5;    // directed-edge capacity (two per undirected edge)

int n, m;
int ans = 0;       // total tree weight
int num_edge = 0;  // edges stored so far (1-based)
int dis[maxn];     // lightest known edge joining the vertex to the tree
int vis[maxn];     // 1 once the vertex is in the tree
int head[maxn];    // first edge out of each vertex, -1 if none

// Max-heap keyed on (-edge weight, vertex), i.e. a min-heap on weight.
priority_queue< pair<int, int> > q;

// Forward-star edge: target vertex, next edge from the same source, weight.
struct node {
    int to;
    int nxt;
    int dis;
};
node edge[maxm];

void origin() {
    memset(dis, inf, sizeof(dis));
    memset(head, -1, sizeof(head));
}

// Reads the next decimal integer (optionally preceded by '-') from stdin,
// skipping any other characters in front of it. Returns 0 at end of input:
// `ch` is an int so EOF is distinguishable from a real character, and the
// skip loop stops there instead of spinning forever.
int read() {
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    bool negative = false;
    if (ch == '-') {
        negative = true;
        ch = getchar();
    }
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -x : x;
}

void addedge(int from, int to, int dis) {
    edge[++num_edge].nxt = head[from];
    edge[num_edge].to = to;
    edge[num_edge].dis = dis;
    head[from] = num_edge;
}

// Heap-optimised Prim grown from vertex 1. Afterwards dis[v] holds the weight
// of the edge that attached v to the tree (vertices never reached keep inf).
void prim() {
    dis[1] = 0;
    q.push(make_pair(0, 1));
    while (!q.empty()) {
        const int u = q.top().second;
        q.pop();
        if (vis[u]) continue;  // stale entry: u already joined the tree
        vis[u] = 1;
        for (int i = head[u]; i != -1; i = edge[i].nxt) {
            const int v = edge[i].to;
            const int w = edge[i].dis;
            if (vis[v] || dis[v] <= w) continue;
            dis[v] = w;
            q.push(make_pair(-w, v));
        }
    }
}

// Prints x in decimal to stdout (no separator). Negation is done in unsigned
// arithmetic so INT_MIN prints correctly instead of hitting signed overflow (UB).
void write(int x) {
    unsigned int mag = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        mag = 0u - mag;
    }
    char digits[12];
    int n_digits = 0;
    do {
        digits[n_digits++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag);
    while (n_digits > 0) putchar(digits[--n_digits]);
}

int main() {
    origin();
    n = read(), m = read();
    rep(i, 1, m) {
        int u, v, w;
        u = read(), v = read(), w = read();
        addedge(u, v, w);
        addedge(v, u, w);
    }
    prim();
    rep(i, 2, n) ans += dis[i];
    write(ans);
    return 0;
}
View Code

 

最近公共祖先(LCA):

树上倍增法:

#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for (register int i = (a); i <= (b); i++)
#define per(i, a, b) for (register int i = (a); i >= (b); i--)

const int maxn = 5e5 + 5;

// n nodes, m queries, root S; t = highest jump exponent used in fa[][].
int n, m, t, S;
int num_edge = 0;
int dep[maxn];      // depth, root = 1; 0 means "not visited yet"
int head[maxn];     // first edge out of each node, -1 if none
int fa[maxn][20];   // fa[v][j]: 2^j-th ancestor of v (0 above the root)

queue<int> q;

// Forward-star edge (unweighted tree edge, stored in both directions).
struct node {
    int to;
    int nxt;
};
node edge[maxn << 1];

void origin() {memset(head, -1, sizeof(head));}

// Reads the next decimal integer (optionally preceded by '-') from stdin,
// skipping any other characters in front of it. Returns 0 at end of input:
// `ch` is an int so EOF is distinguishable from a real character, and the
// skip loop stops there instead of spinning forever.
int read() {
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    bool negative = false;
    if (ch == '-') {
        negative = true;
        ch = getchar();
    }
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -x : x;
}

void addedge(int from, int to) {
    edge[++num_edge].nxt = head[from];
    edge[num_edge].to = to;
    head[from] = num_edge;
}

// BFS from the root S: sets dep[] and the binary-lifting table fa[][].
// BFS order guarantees a parent's row is complete before its children use it.
void bfs(int S) {
    dep[S] = 1;
    q.push(S);
    while (!q.empty()) {
        const int u = q.front();
        q.pop();
        for (int i = head[u]; i != -1; i = edge[i].nxt) {
            const int v = edge[i].to;
            if (dep[v]) continue;  // already discovered (this includes u's parent)
            dep[v] = dep[u] + 1;
            fa[v][0] = u;
            // The 2^j-th ancestor is the 2^(j-1)-th ancestor of the 2^(j-1)-th ancestor.
            for (int j = 1; j <= t; ++j) fa[v][j] = fa[fa[v][j - 1]][j - 1];
            q.push(v);
        }
    }
}

int lca(int u, int v) {
    if (dep[u] > dep[v]) u ^= v ^= u ^= v;
    per(i, t, 0)
        if (dep[fa[v][i]] >= dep[u]) v = fa[v][i];
    if (u == v) return u;
    per(i, t, 0) {
        if (fa[u][i] != fa[v][i]) {
            u = fa[u][i];
            v = fa[v][i];
        }
    }
    return fa[u][0];
}

// Prints x in decimal to stdout (no separator). Negation is done in unsigned
// arithmetic so INT_MIN prints correctly instead of hitting signed overflow (UB).
void write(int x) {
    unsigned int mag = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        mag = 0u - mag;
    }
    char digits[12];
    int n_digits = 0;
    do {
        digits[n_digits++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag);
    while (n_digits > 0) putchar(digits[--n_digits]);
}

int main() {
    origin();
    n = read(), m = read(), S = read();
    t = (int)(log(n) / log(2)) + 1;
    rep(i, 1, n - 1) {
        int u, v;
        u = read(), v = read();
        addedge(u, v);
        addedge(v, u);
    }
    bfs(S);
    rep(i, 1, m) {
        int u, v;
        u = read(), v = read();
        write(lca(u, v));
        printf("\n");
    } 
    return 0;
}
View Code

 

高精度运算:

#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for (register int i = a; i <= b; i++)
#define per(i, a, b) for (register int i = a; i >= b; i--)

// Each limb holds `power` decimal digits, so limbs lie in [0, base), base = 10^power.
const int power = 4;
const int base = 1e4;
// Capacity: limbs per number, and characters per input string.
const int maxn = 1e4 + 5;

// Raw decimal input strings for the two operands.
char a[maxn];
char b[maxn];

// Non-negative big integer, little-endian in base 10^power:
// a[0] = number of limbs, a[1] = least significant limb.
struct num {
    int a[maxn];

    num() { memset(a, 0, sizeof(a)); }

    int& operator [](int x) { return a[x]; }

    // Builds the number from a digit string given least-significant digit first.
    num(char *s) {
        memset(a, 0, sizeof(a));
        const int len = strlen(s);
        a[0] = (len + power - 1) / power;
        int place = 1;  // weight of digit i inside its limb
        for (int i = 0; i < len; ++i) {
            if (i % power == 0) place = 1;
            a[i / power + 1] += place * (s[i] - '0');
            place *= 10;
        }
    }

    // Appends limb k above the current top; a zero limb is skipped while empty.
    void add(int k) {
        if (!k && !a[0]) return;
        ++a[0];
        a[a[0]] = k;
    }

    // Reverses limbs 1..a[0] in place (switches little- and big-endian order).
    void rever() { reverse(a + 1, a + 1 + a[0]); }

    // Prints the value: top limb unpadded, the remaining limbs zero-padded to `power` digits.
    void write() {
        printf("%d", a[a[0]]);
        for (int i = a[0] - 1; i >= 1; --i) printf("%0*d", power, a[i]);
    }
} p, q, ans;

// Returns 1 iff p < q, comparing limb counts first and then limbs from the top.
// The original's second test was `p[0] > p[0]` (always false, a typo for
// `p[0] > q[0]`), so longer p fell through to the limb scan.
int operator <(num &p, num &q) {
    if (p[0] != q[0]) return p[0] < q[0];
    for (int i = p[0]; i > 0; i--)
        if (p[i] != q[i])
            return p[i] < q[i];
    return 0;
}

// Limb-wise addition with carry propagation.
num operator +(num &p, num &q) {
    num a;
    a[0] = max(p[0], q[0]);
    for (int i = 1; i <= a[0]; ++i) {
        const int s = a[i] + p[i] + q[i];  // a[i] already holds the carry from limb i - 1
        a[i] = s % base;
        a[i + 1] = s / base;
    }
    while (a[a[0] + 1]) ++a[0];  // absorb the final carry limb
    return a;
}

// p - q for p >= q (callers compare first), with borrow propagation; zero
// results end with a[0] == 0.
num operator -(num &p, num &q) {
    num a = p;
    for (int i = 1; i <= a[0]; ++i) {
        a[i] -= q[i];
        if (a[i] >= 0) continue;
        a[i] += base;  // borrow one unit from the next limb
        --a[i + 1];
    }
    while (a[0] > 0 && a[a[0]] == 0) --a[0];
    return a;
}

// Schoolbook multiplication, O(p[0] * q[0]) limb products.
num operator *(num &p, num &q) {
    num a;
    // A limb count of 0 (e.g. a zero result of operator-) means the value is 0.
    if (!p[0] || !q[0]) {
        a[0] = 1;
        return a;
    }
    a[0] = p[0] + q[0] - 1;
    for (int i = 1; i <= p[0]; i++)
        for (int j = 1; j <= q[0]; j++) {
            int &cell = a[i + j - 1];
            cell += p[i] * q[j];            // < base^2 + small carry: fits in int
            a[i + j] += cell / base;
            cell %= base;
        }
    if (a[a[0] + 1]) a[0]++;
    // Drop zero limbs at the top. Without this, 0 * 12345 left two zero limbs
    // and write() printed "00000" instead of "0".
    while (a[0] > 1 && !a[a[0]]) a[0]--;
    return a;
}

// Long division p / q (quotient only; the remainder is discarded), one limb of p at a time.
// `b` is the running remainder. Between iterations it is kept in REVERSED
// (most-significant-first) order, so add() appends the next limb of p as the
// new least-significant limb. rever() then switches it back to the
// little-endian layout that operator< and operator- read.
// Each quotient limb is found by repeated subtraction (at most base - 1 times).
// NOTE(review): with q == 0 and p != 0, `!(b < q)` stays true forever and the
// inner loop never terminates; callers must rule out division by zero.
num operator /(num &p, num &q) {
    num a, b;
    per(i, p[0], 1) {
        b.add(p[i]);    // shift remainder up one limb and bring down p[i]
        b.rever();      // back to little-endian for comparison/subtraction
        while (!(b < q)) {
            b = b - q;
            a[i]++;     // quotient limb i counts the subtractions
        }
        b.rever();      // keep most-significant-first for the next add()
    }
    a[0] = p[0];
    while (a[0] && !a[a[0]]) a[0]--;  // trim leading zero limbs (0 -> a[0] == 0)
    return a;
}

// Reads two non-negative decimal integers and prints, one per line:
// p + q, p - q (signed), p * q and p / q (integer quotient). The original
// printed all four results back to back with no separator.
int main() {
    scanf("%s", a), scanf("%s", b);
    // num(char*) expects the least significant digit first.
    reverse(a, a + strlen(a));
    reverse(b, b + strlen(b));
    p = num(a), q = num(b);
    ans = p + q;
    ans.write();
    putchar('\n');
    // num stores magnitudes only, so print |p - q| with an explicit sign.
    if (p < q) {
        putchar('-');
        ans = q - p;
    } else {
        ans = p - q;
    }
    ans.write();
    putchar('\n');
    ans = p * q;
    ans.write();
    putchar('\n');
    ans = p / q;
    ans.write();
    putchar('\n');
    return 0;
}
View Code

 

线段树:

#include <bits/stdc++.h>
using namespace std;
#define ls rt << 1
#define rs rt << 1 | 1
#define rep(i, a, b) for (register int i = (a); i <= (b); i++)

const int maxn = 1e5 + 5;

int n, m;
int a[maxn];  // initial array, 1-based

// Node of a sum segment tree with range-add lazy tags.
// NOTE(review): sum and lazy are int; large values or long ranges can
// overflow, so widen them if the input range requires it.
struct segment_tree {
    int l, r;  // segment [l, r] covered by this node
    int sum;   // sum of the segment, this node's own pending tag included
    int lazy;  // add still owed to both children
} T[maxn << 2];

// Terse field accessors used by the tree routines.
#define l(x) T[x].l
#define r(x) T[x].r
#define sum(x) T[x].sum
#define lazy(x) T[x].lazy
#define length(x) (r(x) - l(x) + 1)

// Reads the next decimal integer (optionally preceded by '-') from stdin,
// skipping any other characters in front of it. Returns 0 at end of input:
// `ch` is an int so EOF is distinguishable from a real character, and the
// skip loop stops there instead of spinning forever.
int read() {
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    bool negative = false;
    if (ch == '-') {
        negative = true;
        ch = getchar();
    }
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -x : x;
}

// Builds node rt over [l, r] from a[], recursing until single elements.
void build(int rt, int l, int r) {
    l(rt) = l;
    r(rt) = r;
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(ls, l, mid);
        build(rs, mid + 1, r);
        sum(rt) = sum(ls) + sum(rs);
    } else {
        sum(rt) = a[l];
    }
}

// Pushes rt's pending add down to both children and clears it.
void spread(int rt) {
    const int tag = lazy(rt);
    if (!tag) return;
    sum(ls) += tag * length(ls);
    lazy(ls) += tag;
    sum(rs) += tag * length(rs);
    lazy(rs) += tag;
    lazy(rt) = 0;
}

// Adds k to every element of [l, r] within the subtree of rt.
void change(int rt, int l ,int r, int k) {
    if (l <= l(rt) && r(rt) <= r) {
        // Fully covered: update the sum now and defer the children via the tag.
        sum(rt) += k * length(rt);
        lazy(rt) += k;
        return;
    }
    // Push the pending tag down BEFORE descending (standard order). The
    // original pushed it afterwards, which is only correct because '+'
    // commutes; it breaks as soon as the tag becomes assignment or multiply.
    spread(rt);
    const int mid = (l(rt) + r(rt)) >> 1;
    if (l <= mid) change(ls, l, r, k);
    if (r > mid) change(rs, l, r, k);
    sum(rt) = sum(ls) + sum(rs);
}

// Sum of the elements in [l, r] within the subtree of rt.
int query(int rt, int l, int r) {
    if (l <= l(rt) && r(rt) <= r) return sum(rt);
    spread(rt);  // children must be up to date before reading them
    const int mid = (l(rt) + r(rt)) >> 1;
    int lo_part = 0;
    int hi_part = 0;
    if (l <= mid) lo_part = query(ls, l, r);
    if (mid < r) hi_part = query(rs, l, r);
    return lo_part + hi_part;
}

// Prints x in decimal to stdout (no separator). Negation is done in unsigned
// arithmetic so INT_MIN prints correctly instead of hitting signed overflow (UB).
void write(int x) {
    unsigned int mag = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        mag = 0u - mag;
    }
    char digits[12];
    int n_digits = 0;
    do {
        digits[n_digits++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag);
    while (n_digits > 0) putchar(digits[--n_digits]);
}

int main() {
    n = read(), m = read();
    rep(i, 1, n) a[i] = read();
    build(1, 1, n);
    while (m--) {
        int opt, x, y, k;
        opt = read(), x = read(), y = read();
        if (opt == 1) {
            k = read();
            change(1, x, y, k);
        }
        else {
            write(query(1, x, y));
            printf("\n");
        }
    }
    return 0;
}
View Code

 

区间动规:

// Interval DP skeleton (min form): dp[l][r] is built from every split point k.
// Because the loop only ever takes min() into dp[l][r], dp must be prepared
// beforehand: base cases dp[i][i], and a large value (e.g. inf) for every
// longer interval. val[l][r] is the problem-specific cost of combining [l, r].
rep(len, 2, n)
    rep(l, 1, n - len + 1) {
        int r = l + len - 1;  // the semicolon was missing here
        rep(k, l, r - 1)
            dp[l][r] = min(dp[l][r], dp[l][k] + dp[k + 1][r] + val[l][r]);
    }
write(dp[1][n]);
View Code

 

数位动规:

顺序版:

// Most of the time the memo table does not need this many dimensions.
// Marks every memo entry as "not computed" (-1: every byte set to 0xff).
void origin() {
    memset(dp, -1, sizeof dp);
}

// Memoised digit DP that fills digits from the most significant one.
//   pos   - 1-based digit position being placed (1 = most significant)
//   pre   - presumably the previously placed digit (0 while in leading zeros)
//   limit - 1 while the digits so far equal the bound's prefix
//   lead  - 1 while only leading zeros have been placed
//   sum   - accumulated state, returned as the answer once every digit is placed
//   ???   - problem-specific extra state (template placeholder)
int dfs(int pos, int pre, int limit, int lead, int sum, ???) {
    int ans = 0;
    if (pos > len) return sum;
    // Only states free of the bound and of leading zeros are shared, so only those are cached.
    if (!limit && !lead && dp[pos][pre][limit][lead][sum]??? != -1) return dp[pos][pre][limit][lead][sum]???;
    // num[] is least-significant-first, so the pos-th digit from the top is num[len - pos + 1].
    int res = limit ? num[len - pos + 1] : 9;
    rep(i, 0, res) {
        if (!i && lead) ans += dfs(pos + 1, 0, (i == res) && limit, 1, ???);       // still leading zeros
        else if (i && lead) ans += dfs(pos + 1, i, (i == res) && limit, 0, ???);  // first significant digit
        // Past the leading zeros, lead is always 0. It is now passed explicitly;
        // the original call dropped this argument.
        else if (/* problem-specific condition */) ans += dfs(pos + 1, ???, (i == res) && limit, 0, ???);
    }
    if (!limit && !lead) dp[pos][pre][limit][lead][sum]??? = ans;
    return ans;
}

// Counts over [0, x]: splits x into digits and runs the DP with the bound active.
int work(int x) {
    origin();
    // num[1] is the units digit; len ends as the digit count.
    for (len = 0; x; x /= 10) num[++len] = x % 10;
    return dfs(1, 0, 1, 1, 0, ???);
}

// Answer over [l, r] as prefix(r) - prefix(l - 1).
int main() {
    l = read();
    r = read();
    write(work(r) - work(l - 1));
    return 0;
}
View Code

 

逆序版:

// Most of the time the memo table does not need this many dimensions.
// Marks every memo entry as "not computed" (-1: every byte set to 0xff).
void origin() {
    memset(dp, -1, sizeof dp);
}

// Memoised digit DP; pos counts down from len (most significant) to 1.
//   pos   - digit position being placed; 0 means every digit is placed
//   pre   - presumably the previously placed digit (0 while in leading zeros)
//   limit - 1 while the digits so far equal the bound's prefix
//   lead  - 1 while only leading zeros have been placed
//   sum   - accumulated state, returned as the answer once every digit is placed
//   ???   - problem-specific extra state (template placeholder)
int dfs(int pos, int pre, int limit, int lead, int sum, ???) {
    int ans = 0;
    if (!pos) return sum;
    // Only states free of the bound and of leading zeros are shared, so only those are cached.
    if (!limit && !lead && dp[pos][pre][limit][lead][sum]??? != -1) return dp[pos][pre][limit][lead][sum]???;
    int res = limit ? num[pos] : 9;
    rep(i, 0, res) {
        if (!i && lead) ans += dfs(pos - 1, 0, (i == res) && limit, 1, ???);       // still leading zeros
        else if (i && lead) ans += dfs(pos - 1, i, (i == res) && limit, 0, ???);  // first significant digit
        // Past the leading zeros, lead is always 0. It is now passed explicitly;
        // the original call dropped this argument.
        else if (/* problem-specific condition */) ans += dfs(pos - 1, ???, (i == res) && limit, 0, ???);
    }
    if (!limit && !lead) dp[pos][pre][limit][lead][sum]??? = ans;
    return ans;
}

// Counts over [0, x]: splits x into digits and runs the DP from the top digit.
int work(int x) {
    origin();
    // num[1] is the units digit; len ends as the digit count.
    for (len = 0; x; x /= 10) num[++len] = x % 10;
    return dfs(len, 0, 1, 1, 0, ???);
}

// Answer over [l, r] as prefix(r) - prefix(l - 1).
int main() {
    l = read();
    r = read();
    write(work(r) - work(l - 1));
    return 0;
}
View Code

 

网络最大流:

#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for (register int i = (a); i <= (b); i++)

const int inf = 0x3f3f3f3f;
const int maxn = 1e4 + 5;  // vertex capacity
const int maxm = 1e6 + 5;  // arc capacity (forward + residual)

// n vertices, m input edges, source S, sink T.
int n, m, S, T;
int ans = 0;
// Starts at -1 so arcs are stored at indices 0, 1, 2, ... and each forward
// arc (even index) pairs with its residual arc at index ^ 1.
int num_edge = -1;
int cur[maxn];   // current-arc pointer for the ongoing blocking-flow phase
int dep[maxn];   // BFS level from S (0 = unreached)
int head[maxn];  // first arc out of each vertex, -1 if none

queue<int> q;

// Arc: target vertex, remaining capacity, next arc from the same vertex.
struct node {
    int to;
    int dis;
    int nxt;
};
node edge[maxm];

void origin() {memset(head, -1, sizeof(head));}

// Reads the next decimal integer (optionally preceded by '-') from stdin,
// skipping any other characters in front of it. Returns 0 at end of input:
// `ch` is an int so EOF is distinguishable from a real character, and the
// skip loop stops there instead of spinning forever.
int read() {
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    bool negative = false;
    if (ch == '-') {
        negative = true;
        ch = getchar();
    }
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -x : x;
}

void addedge(int from, int to, int dis) {
    edge[++num_edge].nxt = head[from];
    edge[num_edge].to = to;
    edge[num_edge].dis = dis;
    head[from] = num_edge;
}

// Builds the level graph over arcs with remaining capacity and resets the
// current-arc pointers. Returns 1 iff T is still reachable from S.
int bfs(int S, int T) {
    memset(dep, 0, sizeof(dep));
    while (!q.empty()) q.pop();
    memcpy(cur, head, sizeof(cur));
    dep[S] = 1;
    q.push(S);
    while (!q.empty()) {
        const int u = q.front();
        q.pop();
        for (int i = head[u]; i != -1; i = edge[i].nxt) {
            const int v = edge[i].to;
            if (dep[v] || !edge[i].dis) continue;  // already levelled, or saturated arc
            dep[v] = dep[u] + 1;
            q.push(v);
        }
    }
    return dep[T] ? 1 : 0;
}

// Sends up to `flow` units from u towards T along level-increasing arcs;
// returns the amount actually pushed.
int dfs(int u, int flow) {
    if (u == T || flow == 0) return flow;
    int used = 0;
    for (int i = cur[u]; i != -1; i = edge[i].nxt) {
        cur[u] = i;  // current arc: later visits in this phase resume here
        const int v = edge[i].to;
        if (dep[v] != dep[u] + 1) continue;
        const int d = dfs(v, min(flow, edge[i].dis));
        if (d == 0) continue;
        used += d;
        flow -= d;
        edge[i].dis -= d;
        edge[i ^ 1].dis += d;  // paired residual arc
        if (flow == 0) break;
    }
    if (used == 0) dep[u] = -2;  // dead end: no later call this phase can match u's level
    return used;
}

// Dinic's max flow from S to T: alternate level-graph BFS and blocking-flow DFS.
int dinic() {
    int total = 0;
    while (bfs(S, T)) total += dfs(S, inf);
    return total;
}

// Prints x in decimal to stdout (no separator). Negation is done in unsigned
// arithmetic so INT_MIN prints correctly instead of hitting signed overflow (UB).
void write(int x) {
    unsigned int mag = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        mag = 0u - mag;
    }
    char digits[12];
    int n_digits = 0;
    do {
        digits[n_digits++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag);
    while (n_digits > 0) putchar(digits[--n_digits]);
}

int main() {
    origin();
    n = read(), m = read(), S = read(), T = read();
    rep(i, 1, m) {
        int u, v, w;
        u = read(), v = read(), w = read();
        addedge(u, v, w);
        addedge(v, u, 0);
    }
    ans = dinic();
    write(ans);
    return 0;
}
View Code

 

二叉搜索树 & 平衡树:

splay:

#include <bits/stdc++.h>
using namespace std;

const int maxn = 1e5 + 5;

// n operations; rt = root node (0 = empty tree); tot = nodes allocated so far.
int n, rt, tot = 0;
// Per node: parent, multiplicity, key, subtree size (counting multiplicity)
// and left/right child. Node 0 is the null sentinel.
// ch has two columns: only ch[x][0] / ch[x][1] are ever indexed (the third
// column of the original was dead weight).
// NOTE: under C++17, `size` clashes with std::size brought in by
// `using namespace std`, so unqualified uses are ambiguous; write ::size.
int fa[maxn], cnt[maxn], val[maxn], size[maxn], ch[maxn][2];

// 1 if x is its parent's right child, otherwise 0.
int get(int x) {
    return ch[fa[x]][1] == x ? 1 : 0;
}
    
void maintain(int x) {size[x] = size[ch[x][0]] + size[ch[x][1]] + cnt[x];}
    
void clear(int x) {fa[x] = cnt[x] = val[x] = size[x] = ch[x][0] = ch[x][1] = 0;}

// Predecessor of the root: rightmost node of its left subtree, or -1 if none.
int pre() {
    int cnr = ch[rt][0];
    if (cnr == 0) return -1;
    for (; ch[cnr][1]; cnr = ch[cnr][1]) {}
    return cnr;
}
    
// Successor of the root: leftmost node of its right subtree, or -1 if none.
int nxt() {
    int cnr = ch[rt][1];
    if (cnr == 0) return -1;
    for (; ch[cnr][0]; cnr = ch[cnr][0]) {}
    return cnr;
}
   
void rotate(int x) {
    int y = fa[x], z = fa[y], chk = get(x);
    ch[y][chk] = ch[x][chk ^ 1];
    fa[ch[x][chk ^ 1]] = y;
    ch[x][chk ^ 1] = y;
    fa[y] = x;
    fa[x] = z;
    if (z) ch[z][y == ch[z][1]] = x;
    maintain(y);
    maintain(x);
}

// Splays x to the root: zig-zig rotates the parent first, zig-zag rotates x twice.
void splay(int x) {
    for (int f = fa[x]; f; f = fa[x]) {
        if (fa[f]) rotate(get(x) == get(f) ? f : x);
        rotate(x);
    }
    rt = x;
}

int kth(int k) {
    int cnr = rt;
    while (1) {
        if (ch[cnr][0] && k <= size[ch[cnr][0]]) cnr = ch[cnr][0];
        else {
            k -= cnt[cnr] + size[ch[cnr][0]];
            if (k <= 0) return val[cnr];
            cnr = ch[cnr][1];
        }
    }
}

int rk(int k) {
    int res = 0, cnr = rt;
    while (1) {
        if (k < val[cnr]) cnr = ch[cnr][0];
        else {
            res += size[ch[cnr][0]];
            if (k == val[cnr]) {
                splay(cnr);
                return res + 1;
            }
            res += cnt[cnr];
            cnr = ch[cnr][1];
        }
    }
}

// Inserts one copy of k (bumping the count if k already exists) and splays
// the affected node to the root.
void insert(int k) {
    if (rt == 0) {  // empty tree: the new node becomes the root
        ++tot;
        val[tot] = k;
        cnt[tot]++;
        rt = tot;
        maintain(rt);
        return;
    }
    int f = 0, cnr = rt;
    for (;;) {
        if (val[cnr] == k) {
            ++cnt[cnr];
            maintain(cnr);
            maintain(f);
            splay(cnr);
            return;
        }
        f = cnr;
        const int dir = val[cnr] < k;
        cnr = ch[cnr][dir];
        if (cnr) continue;
        // Fell off the tree: hang a new leaf under f on side dir.
        ++tot;
        val[tot] = k;
        ++cnt[tot];
        fa[tot] = f;
        ch[f][dir] = tot;
        maintain(tot);
        maintain(f);
        splay(tot);
        return;
    }
}

void erase(int k) {
    rk(k);
    if (cnt[rt] > 1) {
        cnt[rt]--;
        maintain(rt);
        return;
    }
    if (!ch[rt][0] && !ch[rt][1]) {
        clear(rt);
        rt = 0;
        return;
    }
    if (!ch[rt][0]) {
        int cnr = rt;
        rt = ch[rt][1];
        fa[rt] = 0;
        clear(cnr);
        return;
    }
    if (!ch[rt][1]) {
        int cnr = rt;
        rt = ch[rt][0];
        fa[rt] = 0;
        clear(cnr);
        return;
    }
    int x = pre(), cnr = rt;
    splay(x);
    fa[ch[cnr][1]] = x;
    ch[x][1] = ch[cnr][1];
    clear(cnr);
    maintain(rt);
}

// Reads the next decimal integer (optionally preceded by '-') from stdin,
// skipping any other characters in front of it. Returns 0 at end of input:
// `ch` is an int so EOF is distinguishable from a real character, and the
// skip loop stops there instead of spinning forever.
int read() {
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    bool negative = false;
    if (ch == '-') {
        negative = true;
        ch = getchar();
    }
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -x : x;
}

// Prints x in decimal to stdout (no separator). Negation is done in unsigned
// arithmetic so INT_MIN prints correctly instead of hitting signed overflow (UB).
void write(int x) {
    unsigned int mag = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        mag = 0u - mag;
    }
    char digits[12];
    int n_digits = 0;
    do {
        digits[n_digits++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag);
    while (n_digits > 0) putchar(digits[--n_digits]);
}

// Processes n operations "opt x":
//   1 insert x, 2 erase x, 3 rank of x, 4 k-th smallest,
//   5 predecessor of x, otherwise successor of x.
// For 5/6, x is inserted temporarily so it sits at the root while pre()/nxt() look around it.
int main() {
    n = read();
    while (n--) {
        const int opt = read();
        const int x = read();
        switch (opt) {
            case 1:
                insert(x);
                break;
            case 2:
                erase(x);
                break;
            case 3:
                write(rk(x));
                printf("\n");
                break;
            case 4:
                write(kth(x));
                printf("\n");
                break;
            case 5:
                insert(x);
                write(val[pre()]);
                printf("\n");
                erase(x);
                break;
            default:
                insert(x);
                write(val[nxt()]);
                printf("\n");
                erase(x);
                break;
        }
    }
    return 0;
}
View Code

 对不起，这篇文章暂时鸽了（后续内容待补充）。

posted @ 2019-06-11 17:01  雲裏霧裏沙  阅读(343)  评论(0编辑  收藏  举报