算法模板By Roshin

可前往 Gitee 码云 获取.md文件,如果可以的话麻烦点个 \(star\) 或者 \(fork\) 咯~

memset也会卡时间
\((x\mod a) \mod ba = (x \mod ba )\mod a\)
\(1~n\) 中有多少数是 \(i\) 的倍数,就是 \(n / i\)
数组非全局变量要初始化
多动手推一推
复杂变量学会适当引用
\(\Sigma_{i=0}^n{C_n^i * i} = n*2^{n-1}\)

目录

STL

汇总

vector, 变长数组,倍增的思想
    size()  返回元素个数
    empty()  返回是否为空
    clear()  清空
    front()/back()
    push_back()/pop_back()
    begin()/end()
    []
    支持比较运算,按字典序

pair<int, int>
    first, 第一个元素
    second, 第二个元素
    支持比较运算,以first为第一关键字,以second为第二关键字(字典序)

string,字符串
    size()/length()  返回字符串长度
    empty()
    clear()
    substr(起始下标,(子串长度))  返回子串
    c_str()  返回字符串所在字符数组的起始地址

queue, 队列
    size()
    empty()
    push()  向队尾插入一个元素
    front()  返回队头元素
    back()  返回队尾元素
    pop()  弹出队头元素

priority_queue, 优先队列,默认是大根堆
    size()
    empty()
    push()  插入一个元素
    top()  返回堆顶元素
    pop()  弹出堆顶元素
    定义成小根堆的方式:priority_queue<int, vector<int>, greater<int>> q;

stack, 栈
    size()
    empty()
    push()  向栈顶插入一个元素
    top()  返回栈顶元素
    pop()  弹出栈顶元素

deque, 双端队列
    size()
    empty()
    clear()
    front()/back()
    push_back()/pop_back()
    push_front()/pop_front()
    begin()/end()
    []

set, map, multiset, multimap, 基于平衡二叉树(红黑树),动态维护有序序列
    size()
    empty()
    clear()
    begin()/end()
    ++, -- 返回前驱和后继,时间复杂度 O(logn)
    next(), prev(), 与上同理

set/multiset
    insert()  插入一个数
    find()  查找一个数
    count()  返回某一个数的个数
    erase()
        (1) 输入是一个数x,删除所有x   O(k + logn)
        (2) 输入一个迭代器,删除这个迭代器
    lower_bound()/upper_bound()
        lower_bound(x)  返回大于等于x的最小的数的迭代器
        upper_bound(x)  返回大于x的最小的数的迭代器
map/multimap
    insert()  插入的数是一个pair
    erase()  输入的参数是pair或者迭代器
    find()
    []  注意multimap不支持此操作。 时间复杂度是 O(logn)
    lower_bound()/upper_bound()

unordered_set, unordered_map, unordered_multiset, unordered_multimap, 哈希表
    和上面类似,增删改查的时间复杂度是 O(1)
    不支持 lower_bound()/upper_bound(), 迭代器的++,--

bitset, 圧位
    bitset<10000> s;
    ~, &, |, ^
    >>, <<
    ==, !=
    []

    count()  返回有多少个1

    any()  判断是否至少有一个1
    none()  判断是否全为0

    set()  把所有位置成1
    set(k, v)  将第k位变成v
    reset()  把所有位变成0
    flip()  等价于~
    flip(k) 把第k位取反

lower_bound

搜索数组中第一个大于等于 ≥ x的数,返回迭代器

不存在返回end

#include<algorithm>
vector<int> a;
int pos = lower_bound(a.begin(), a.end(), x) - a.begin();	// pos便是目标下标
int pos = lower_bound(a.begin(), a.end(), x, greater<int>() ) - a.begin()  // 找到第一个小于等于x的数

upper_bound

与lower_bound用法相似,搜索数组中第一个大于 > x的数,返回迭代器

不存在返回end

#include<algorithm>
vector<int> a;
// pos便是目标下标
int pos = upper_bound(a.begin(), a.end(), x) - a.begin();	
// 找到第一个小于x的数
int pos = upper_bound(a.begin(), a.end(), x, greater<int>() ) - a.begin();	

基础算法

分治

分治法把一个问题划分为若干个规模更小的同类子问题,对这些子问题递归求解,然后再回溯时通过它们推导出原问题的解。

分治求等比数列和

用到了快速幂,求 \(1+p+\cdots+p^k\) 的和

// 分治求等比数列和复杂度(log k)
ll sum(int p, int k){
    if(!k) return 1;
    if(k & 1)
        return (1 + qmi(p, (k + 1) / 2)) % mod * sum(p, (k - 1) / 2) % mod;
    return ((1 + qmi(p, k / 2)) * sum(p, k / 2 - 1) + qmi(p, k)) % mod;
}

分治解决最近点问题

#define pb push_back
#define x first
#define y second
#define mkp make_pair
#define endl "\n"
using namespace std;
// 最近点问题,分治解决:按x坐标排序,ans = min(左右两边最短距离,中间线旁边点的最短距离)
// 结论:中间线旁边两点最短距离,最多只需要找6个点
const int N = 2e5 + 10;
const double INF = 1e9 + 10, eps = 1e-5;
int n;
double min_d;
struct P{
    int x, y;
    bool type;
    bool operator < (const P& a)const{
        return x < a.x;
    }
}p[N];

double get(P& a, P& b){
    if(a.type == b.type) return min_d;
    double dx = a.x - b.x, dy = a.y - b.y;
    return sqrt(dx * dx + dy * dy);
}

double dfs(int l, int r){
    if(l == r) return min_d;
    int mid = (l + r) >> 1;
    double ans = min(dfs(l, mid), dfs(mid + 1, r));
    double midx = p[mid].x;
    int i = l, j = mid + 1, cnt = 0;
    P tmp[N];
    // 按纵坐标大小归并排序
    while(i <= mid && j <= r){
        if(p[i].y <= p[j].y) tmp[cnt++] = p[i++];
        else tmp[cnt++] = p[j++];
    }
    while(i <= mid) tmp[cnt++] = p[i++];
    while(j <= r) tmp[cnt++] = p[j++];
    for(int i = l; i <= r; i++)
        p[i] = tmp[i - l];
    cnt = 0;
    // 找出满足要求的点,距离中线最多ans个距离
    for(int i = l; i <= r; i++)
        if(midx - ans <= p[i].x && p[i].x <= ans + midx)
            tmp[cnt++] = p[i];
    // 一定判断纵坐标在 ans 范围内,才能保证复杂度,要加 eps 否则TLE
    for(int i = 0; i < cnt; i++)          
        for(int j = i - 1; j >= 0 && tmp[i].y - tmp[j].y + eps <= ans; j--)
            ans = min(ans, get(tmp[i], tmp[j]));
    min_d = min(min_d, ans);
    return ans;
}

int main(){
    ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
    int T;
    cin >> T;
    while(T--){
        min_d = INF;
        cin >> n;
        for(int i = 0; i < n; i++){
            cin >> p[i].x >> p[i].y;
            p[i].type = true;
        }
        for(int i = n; i < 2 * n; i++){
            cin >> p[i].x >> p[i].y;
            p[i].type = false;
        }
        sort(p, p + 2 * n);
        min_d = get(p[0], p[2 * n - 1]);
        double ans = dfs(0, 2 * n - 1);
        cout << fixed << setprecision(3) << ans << endl;
    }
    return 0;
}

基础排序算法

快速排序

#include<iostream>
using namespace std;
const int N = 1e6 + 10;
int q[N], n;
void quick_sort(int l,int r){
    if(l >= r) return;
    int x = q[l];
    int i = l - 1,j = r + 1;
    while(i < j){
        do i++;while (q[i] < x);
        do j--;while (q[j] > x);
        if(i < j) swap(q[i], q[j]);
    }
    quick_sort(l,j);
    quick_sort(j + 1,r);
}

int main(){
    scanf("%d",&n);
    for(int i = 0;i < n;i++)
        scanf("%d", &q[i]);
    quick_sort(0, n - 1);
    for(int i = 0;i < n;i++)
        printf("%d ",q[i]);
    return 0;
}

归并排序

递归分成相等的子段,子段内部排序后,回溯时合并再排序

const int N = 1e6 + 10;
int q[N],tmp[N];
void mergesort(int a[],int l,int r){
    if(l >= r) return;
    int mid = (l + r) / 2;
    mergesort(q,l,mid);
    mergesort(q,mid+1,r);
    int k = 0,i = l,j = mid + 1;
    while(i <= mid && j <= r){
        if(q[i] <= q[j]) tmp[k++] = q[i++];
        else tmp[k++] = q[j++]; // 求逆序对,在后面加一个 ans += mid - i + 1;
    }
    while(i <= mid) tmp[k++] = q[i++];
    while(j <= r) tmp[k++] = q[j++];
    for(i = l,j = 0;i <= r;i++,j++) q[i] = tmp[j];  
}

int main(){
    int n;
    scanf("%d",&n);
    for(int i = 0;i < n;i++){
        scanf("%d",&q[i]);
    } 
    mergesort(q,0,n-1);
    for(int i = 0;i < n;i++){
        printf("%d ",q[i]);
    } 
    return 0;
}

二分

整数二分

bool check(int x) {/* ... */} // 检查x是否满足某种性质

// 区间[l, r]被划分成[l, mid]和[mid + 1, r]时使用:
int bsearch_1(int l, int r)
{
    while (l < r)
    {
        int mid = l + r >> 1;
        if (check(mid)) r = mid;    // check()判断mid是否满足性质
        else l = mid + 1;
    }
    return l;
}
// 区间[l, r]被划分成[l, mid - 1]和[mid, r]时使用:
int bsearch_2(int l, int r)
{
    while (l < r)
    {
        int mid = l + r + 1 >> 1;
        if (check(mid)) l = mid;
        else r = mid - 1;
    }
    return l;
}

浮点数二分

bool check(double x) {/* ... */} // 检查x是否满足某种性质

double bsearch_3(double l, double r)
{
    const double eps = 1e-6;   // eps 表示精度,取决于题目对精度的要求
    while (r - l > eps)
    {
        double mid = (l + r) / 2;
        if (check(mid)) r = mid;
        else l = mid;
    }
    return l;
}

前缀和、差分

核心思想

前缀和将查询区间和变为 \(O(1)\)

差分将区间修改变为 \(O(1)\)\(O(n)\) 得到原数组

一维前缀和

for(int i = 1; i <= n; i++)
    s[i] = s[i - 1] + a[i];

二维前缀和

for(int i = 1; i <= n; i++)
    for(int j = 1; j <= n; j++)
        s[i][j] = s[i - 1][j] + s[i][j - 1] - s[i - 1][j - 1] + a[i][j];

一维差分

b[i] = a[i] - a[i - 1]
// 对 [l, r] 区间 + c: b[l] += c, b[r + 1] -= c;
// 对 [l, r] 区间 - c: b[l] -= c, b[r + 1] += c;

二维差分

// b[x1][y1] += c 是对顶点到右下角的子矩阵的值增加 c
void insert(int x1, int y1, int x2, int y2, int c){     
    b[x1][y2 + 1] -= c;
    b[x2 + 1][y1] -= c;
    b[x2 + 1][y2 + 1] += c;
    b[x1][y1] += c;
}

双指针

核心思想
省掉重复性的动作,进行优化,能达到 \(O(n)\) 复杂度

区间离散化

unique() + erase() 函数

vector<int> v;
sort(v.begin(), v.end());       // unique仅对相邻元素处理,需要排序
v.erase(unique(v.begin(), b.end()), v.end());   // unique返回末尾重复元素开头位置

贪心

区间选点

给定 \(N\) 个闭区间 \([a_i,b_i]\),请你在数轴上选择尽量少的点,使得每个区间内至少包含一个选出的点。

输出选择的点的最小数量。

位于区间端点上的点也算作区间内。

#include<iostream>
#include<algorithm>
using namespace std;
const int N = 1e5+10;
struct Range{
    int l,r;
    bool operator < (const Range & W){ //重载 < 符号方便以右端点直接sort排序
        return r < W.r;
    }
}range[N];

int main(){
    int n;
    cin >> n;
    for(int i = 0;i < n;i++) cin >> range[i].l >> range[i].r; 
    sort(range,range+n);
    int res = 0,ed = -2e9;
    for(int i = 0;i < n;i++){
        if(range[i].l > ed){        // 后一区间 左端点 大于 ed 值,需要新的点
            res++;
            ed = range[i].r;
        }
    }
    printf("%d",res);
    return 0;
}

最大不相交区间数量

给定 \(N\) 个闭区间 \([ai,bi]\),请你在数轴上选择若干区间,使得选中的区间之间互不相交(包括端点)。

输出可选取区间的最大数量。

#include<iostream>
#include<algorithm>
using namespace std;
const int N = 1e5 + 10;
struct Range{
    int l,r;
    bool operator < (const Range & w){
        return l < w.l;
    }
}range[N];

int main(){
    int n;
    cin >> n;
    for(int i = 0;i < n;i++){
        cin >> range[i].l >> range[i].r;
    }
    sort(range,range+n);
    int res = 0,ed = -2e9;
    for(int i = 0;i < n;i++){
        if(range[i].l <= ed) ed = min(range[i].r,ed);
        else{
            res++;
            ed = range[i].r;
        }
    }
    printf("%d",res);
    return 0;
}

区间分组

给定 N 个闭区间 [ai,bi],请你将这些区间分成若干组,使得每组内部的区间两两之间(包括端点)没有交集,并使得组数尽可能小。

输出最小组数。

#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
// 贪心:左端点排序,用小根堆记录所有组的max_r,对于一个组,若l[i] >= max_r,则将区间放进(随意组),不成立则换其他组放,无组可放,则放入新的组。
using namespace std;
const int N = 1e5 + 10;
int n;

struct Range{
    int l,r;
    bool operator < (const Range & a) const{
        return l < a.l;
    }
}range[N];

int main(){
    scanf("%d",&n);
    for(int i = 0;i < n;i++){
        int l,r;
        scanf("%d%d",&l,&r);
        range[i] = {l,r};
    }
    sort(range,range+n);        // 左端点排序
    priority_queue<int, vector<int>, greater<int>> heap;
    for(int i = 0;i < n;i++){
        auto r = range[i];
        if(heap.empty() || heap.top() >= r.l) heap.push(r.r);       // 没有组成立或不存在组能放进去
        else{
            heap.pop();     // 因为新放入的区间r 一定大于原来的max_r,所以直接pop掉根
            heap.push(r.r);
        }
    }
    printf("%d",heap.size());
    return 0;
}

区间覆盖

给定 \(N\) 个闭区间 \([a_i,b_i]\) 以及一个线段区间 \([s,t]\) ,请你选择尽量少的区间,将指定线段区间完全覆盖。

输出最少区间数,如果无法完全覆盖则输出 −1。

#include<iostream>
#include<algorithm>
using namespace std;
const int N = 1e5 + 10;
int n;
// 贪心:区间按左端点排序,从左往右枚举 左端点小于st 的区间,取到最大的r,若r < st则不能覆盖,然后更新st = r,重复操作直到r >= ed;
struct Range{
    int l,r;
    bool operator < (const Range & w){
        return l < w.l;
    }
}range[N];

int main(){
    int st,ed;
    scanf("%d%d",&st,&ed);
    scanf("%d",&n);
    for(int i = 0;i < n;i++){
        int l,r;
        scanf("%d%d",&l,&r);
        range[i] = {l,r};
    }
    sort(range,range + n);
    bool success = false;       // 检验是否覆盖
    int res = 0;
    for(int i = 0;i < n;i++){
        int r = -2e9;       // r最大值
        int j = i;
        while(j < n && range[j].l <= st){       // 遍历所有 l < st的区间
            r = max(r, range[j].r);
            j++;
        }
        if(r < st){         // 不能完全覆盖
            res = -1;
            break;
        }
        
        res++;
        if(r >= ed){        // 区间覆盖完了
            success = true;
            break;
        }
        
        st = r;         // 更新st
        i = j - 1;      // i取到最后一次取到的区间
    }
    if(!success) res = -1;
    printf("%d",res);
    return 0;
}

数据结构

STL基础

队列

滑动窗口

滑动窗口本质是一个单调的双端队列

#include<deque>
using namespace std;
const int N = 1e6 + 10;
deque<int> q;
int a[N];
int n, k;

// 滑动窗口要用双端队列
int main(){
    cin >> n >> k;      // k 为滑动窗口大小
    for(int i = 1; i <= n; i++){
        cin >> a[i];
    }
    for(int i = 1; i <= n; i++){        // 获取滑动窗口最小值
        int x = a[i];
        while(!q.empty() && i - q.front() > k - 1)
            q.pop_front();
        while(!q.empty() && a[q.back()] >= a[i]){       // 队尾大于等于 x 就 pop_back
            q.pop_back();
        }
        q.push_back(i);
        if(i >= k)
            printf("%d ", a[q.front()]);
    }
    printf("\n");
    q.clear();
    for(int i = 1; i <= n; i++){        // 获取滑动窗口最大值
        int x = a[i];
        while(!q.empty() && i - q.front() > k - 1)
            q.pop_front();
        while(!q.empty() && a[q.back()] <= a[i]){       // 队尾小于 x 就 pop_back
            q.pop_back();
        }
        q.push_back(i);
        if(i >= k)
            printf("%d ", a[q.front()]);
    }
    
    return 0;
}

数组模拟队列

int q[N];

int main(){
    int T;
    cin >> T;
    int hh = 0, tt = -1;
    while(T--){
        string op;
        cin >> op;
        if(op == "push"){       // 加入元素到队尾
            int x;
            cin >> x;
            q[++tt] = x;
        }
        else if(op == "pop")    // 弹出队头
            hh++;
        else if(op == "empty"){ // 查询是否为空
            if(hh <= tt)
                puts("NO");
            else
                puts("YES");
        }
        else        // 查询队头
            printf("%d\n", q[hh]);
    }   
    return 0;
}

数组模拟栈

int s[N];

int main(){
    int T;
    cin >> T;
    int tt = 0;     // tt 表示栈顶
    while(T--){
        string op;
        cin >> op;
        if(op == "push"){       // 入栈
            int x;
            cin >> x;
            s[++tt] = x;
        }
        if(op == "pop")     // 弹出栈顶
            tt--;
        if(op == "empty"){      // 查询栈是否为空
            if(tt > 0)  
                puts("NO");
            else
                puts("YES");
        }
        if(op == "query")       // 查询栈顶
            printf("%d\n", s[tt]);
    }
    
    return 0;
}

单调栈

int s[N];

int main(){
    int n;
    cin >> n;
    int tt = 0;
    for(int i = 0; i < n; i++){
        int x;
        cin >> x;
        while(tt > 0 && s[tt] >= x)     // 栈不为空,栈顶大于等于 x
            tt--;
        if(!tt)
            printf("-1 ");
        else
            printf("%d ", s[tt]);
        s[++tt] = x;        // 将 x 入栈
    }
    
    return 0;
}

笛卡尔树

struct CartesianTree {
    const static int maxn = 1e6 + 10;
    int stk[maxn], l[maxn], r[maxn], n, top;
    /* root is stk[0] */
    void build (int* a, int _n) {
        top = 0, n = _n;        
        for (int i = 1; i <= n; ++ i) l[i] = r[i] = 0;
        for (int i = 1; i <= n; ++ i) {
            int k = top;
            while (k && a[stk[k - 1]] > a[i]) --k;
            if (k) r[stk[k - 1]] = i;
            if (k < top) l[i] = stk[k];
            stk[k++] = i;
            top = k;
        }
        // for (int i = 1; i <= n; ++ i) {
        //     if (l[i]) add(i, l[i]);     // add edge
        //     if (r[i]) add(i, r[i]);
        // }
    }
}ctr;

Trie 树

\(son\) 数组的大小要根据题目而定,一般为 \(元素个数 \times 存储每个元素最多需要的节点数\)

int son[N][26], idx;        // idx存放节点编号, 0号节点既是根也是空节点
int cnt[N];         

void insert(char str[]){    // 插入
    int p = 0;
    for(int i = 0; str[i]; i++){
        int u = str[i] - 'a';
        if(!son[p][u]) son[p][u] = ++ idx;      // 节点不存在则新建一个
        p = son[p][u];
    }
    cnt[p] ++;  // 对末尾节点做标记
}

int query(char str[]){      // 查询
    int p = 0;
    for(int i = 0; str[i]; i++){
        int u = str[i] - 'a';
        if(!son[p][u])  return 0;
        p = son[p][u];
    }
    return cnt[p];
}

求异或 \(x\)\(k\)

是一种广义权值线段树

#include<bits/stdc++.h>
typedef long long ll;
#define pb push_back
#define endl "\n"
using namespace std;
/*----------------------------------------------------------------------------------------------------*/

/* 广义权值线段树,也可以看做是 0-1 Trie
   每个叶子节点的路径代表一个数。
   查询 ^x 的第 k 小,就拿 x 在树上跑,类似线段树二分对比左右子树大小 */

const int N = 2e5 + 10, M = 29;
struct Node{
    int s[2];
    int sz;
}tr[N * 32];

int idx, root;      // 0 为空 

int main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int n, m;
    cin >> n >> m;
    vector<int> a(n);
    root = ++ idx;
    for (int i = 0; i < n; i++) {
        cin >> a[i];
        int p = root;
        for (int j = M; ~j; j--) {
            int x = a[i] >> j & 1;
            tr[p].sz++;
            if (!tr[p].s[x]) tr[p].s[x] = ++ idx;
            p = tr[p].s[x];
        }
        tr[p].sz++;     // 结尾节点计数 + 1
    }
    while (m--) {
        int x, k;
        cin >> x >> k;
        int p = root;
        ll ans = 0;
        for (int i = M; ~i; i --) {
            int val = x >> i & 1;
            if (tr[tr[p].s[val]].sz >= k) p = tr[p].s[val];
            else {
                k -= tr[tr[p].s[val]].sz;
                p = tr[p].s[val ^ 1];
                ans ^= 1 << i;
            }
            assert(p != 0);
        }
        cout << ans << endl;
    }
    return 0;
}

可持久化Trie

  • 空间不够,用最大空间倒退,能开多大开多大
    例题: 最大异或和
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
const int N = 6e5 + 10, M = N * 24;     // 原来要3e5, 3e5次查询 序列长度最多有6e5,每个数最多24位,历史版本Trie所有加起来最多需要 6e5 * 24 个节点
int n, m;
int tr[M][2], root[N], max_id[M], s[N], idx;    

// 可持久化Trie思想:与前一个版本的trie不同的路径上的点都要重新建立
// 题目思路:转换异或前缀和,s[p - 1] ^ s[n] ^ x 最大,在[l - 1, r - 1]找一个 p 使得 s[p - 1] ^ (s[n] ^ x) 最大

// 因为需要先知道子节点下标来确定父节点下标,用递归的写法
void insert(int i, int k, int p, int q){        // 在前缀和中的下标,插入了第几位(从左到右),旧版本对应节点,新版本对应节点
    if(k <= -1){
        max_id[q] = i;
        return ;
    }
    int v = (s[i] >> k) & 1;        // 取出该位
    tr[q][v] = ++ idx;      // 与原来不同的路径要新开点
    if(p)
        tr[q][v ^ 1] = tr[p][v ^ 1];                            // 与原来相同的路径直接复制
    insert(i, k - 1, tr[p][v], tr[q][v]);                       // 递归插入 
    max_id[q] = max(max_id[tr[q][0]], max_id[tr[q][1]]);        // 回溯更新子树最大下标
}

int query(int root, int C, int l){              // 根节点编号,待异或的值,下标限制
    int p = root;
    for(int i = 23; i >= 0; i--){
        int v = (C >> i) & 1;
        if(max_id[tr[p][v ^ 1]] >= l)           // 贪心寻找01相反的节点,且下标符合要求
            p = tr[p][v ^ 1];
        else
            p = tr[p][v];
    }
   return C ^ s[max_id[p]];
}

int main(){
    cin >> n >> m;
    max_id[0] = -1;     // 空节点下标无穷小
    root[0] = ++ idx;   // 一定为根开辟新节点 
    insert(0, 23, 0, root[0]);      // 插入,s[0]也算前缀和
    int x;
    for(int i = 1; i <= n; i++){
        cin >> x;
        s[i] = s[i - 1] ^ x;
        root[i] = ++ idx;       // 为新根开辟新节点
        insert(i, 23, root[i - 1], root[i]);
    }
    while(m--){
        string op;
        cin >> op;
        if(op == "A"){
            ++n;
            cin >> x;
            root[n] = ++ idx;       // 为新根开辟新节点
            s[n] = s[n - 1] ^ x;
            insert(n, 23, root[n - 1], root[n]);
        }
        else{
            int l, r;
            cin >> l >> r >> x;
            cout << query(root[r - 1], s[n] ^ x, l - 1) << endl;    // 区间[l, r],转换后找 p 属于 [l - 1, r - 1], 在root[r - 1]中寻找,下标限制 >= l - 1;
        }
    }
    return 0;
}

并查集

基础操作

并查集初始化

for(int i = 1; i <= n; i++){
    p[i] = i;
    sz[i] = 1;    // 维护以i为根的连通块大小
}

查询+路径压缩

int find(int x){
    if(x != p[x]) return p[x] = find(p[x]);
    return p[x];
}

查询+维护点到根节点距离

int find(int x){
    if(x != p[x]){
        int root = find(p[x]);
        d[x] += d[p[x]];
        p[x] = root;
    }
    return p[x];
}

合并并查集

int pa = find(a), pb = find(b);
p[pa] = pb;
sz[pb] += sz[pa];     // 维护的并查集大小合并
d[pa] = d[b] + dist - d[a];     // 更新 pa 的边权

查询两点是否在同一集合

if(find(a) == find(b))

启发式合并

  • 由于合并时希望操作元素尽量少,就让少的往大的合并,这就是启发式合并
  • \(n\) 个元素和 \(m\) 次查询,时间复杂度为 \(O(mlogn)\)
// 启发式合并
void union(int x, int y){
    int fx = find(x), fy = find(y);
    if(fx == fy) return;
    if(sz[fx] > sz[fy])
        swap(fx, fy);
    p[fx] = fy;
    sz[fy] += sz[fx];
}

按深度合并

  • 每次合并将深度小的一方合并到深度大的一方
  • 路经压缩时,可能破坏深度值,复杂度不变差
// 按深度合并
void union(int x, int y){
    int fx = find(x), fy = find(y);
    if(fx == fy) return;
    if(dep[fx] > dep[fy])
        swap(fx, fy);
    p[fx] = fy;
    if(dep[fx] == dep[fy])  // 只有深度相等才更新
        dep[fy]++;
}

时间复杂度

  • 启发式合并和深度合并,\(n\) 个元素和 \(m\) 次查询,时间复杂度为 \(O(mlogn)\)

  • 一般来说并查集时间复杂度为 \(O(m*\alpha (m, n))\)。其中 \(\alpha\) 为阿克曼函数的反函数,可以认为是一个小常数

  • 无启发式合并,只路径压缩最坏时间复杂度为 \(O(mlogn)\),平均复杂度为 \(O*\alpha(m,n)\)

  • 可以直接认为 \(O(m)\)

哈希表

存储结构

开放寻址法

\(N\) 取输入规模的 \(2~3\) 倍,\(null\) 为初始化值

const int N = 2e5 + 5, null = 0x3f3f3f3f; 
int h[N];

void insert(int x){ // 插入
    int k = (x % N + N) % N;
    while(h[k] != null){
        k++;
        if(k == N)
            k = 0;
    }
    h[k] = x;
}

int find(int x){    // 查找
    int k = (x % N + N) % N;
    while(h[k] != null){
        if(h[k] == x)
            break;
        if(k == N)
            k = 0;
        k++;
    }
    return h[k];
}

拉链法

类似邻接表

const int N = 1e5 + 3;  // 大于输入数据规模的第一个质数

int e[N], ne[N], h[N], idx;

void insert(int x){     // 插入
    int k = (x % N + N) % N;
    e[idx] = x, ne[idx] = h[k], h[k] = idx++;
}

bool find(int x){      // 查找
    int k = (x % N + N) % N;
    for(int i = h[k]; i != -1; i = ne[i]){
        if(e[i] == x)
            return true;
    }
    return false;
}

ST表(RMQ算法)

与线段树区别:区间最值静态查询,不支持修改

一维ST表

预处理: \(O(nlogn)\), 查询:\(O(1)\)
ST表初始化

void init(){
    for(int i = 0; i < M; i++)  // 类似区间DP,先枚举区间长度,再枚举起点
        for(int j = 1; j + (1 << i) - 1 <= n; j++){
            if(!i)
                st[j][0] = a[j];
            else
                st[j][i] = max(st[j][i - 1], st[j + (1 << (i - 1))][i - 1]);
        }
}

ST表查询

int query(int l, int r){
    int len = r - l + 1;
    int k = log(len) / log(2);      // 小于区间长度的最大2的幂次
    return max(st[l][k], st[r - (1 << k) + 1][k]);      // 前后两段最大值的最大值就是区间的最大值
}

二维ST表

\(st[i][j][k][l]\) 表示以顶点 \((i,j)\) 为左上角边长为 \(2^k,2^l\) 的矩阵最值
预处理: \(O(n * m * logn * logm)\),查询复杂度 \(O(1)\)
还存在预处理:\(O(n*m*logn)\),查询 \(O(n)\) 做法。

int st[N][N][M][M], g[N][N];

void init(){		// O(n * m * logn * logm) --- O(1)
	for(int k = 0; k < M; k++)
		for(int l = 0; l < M; l++)
			for(int i = 1; i + (1 << k) - 1 <= n; i++)
				for(int j = 1; j + (1 << l) - 1 <= m; j++)
					if(!k && !l)
						st[i][j][k][l] = g[i][j];
					else if(!k && l)
						st[i][j][k][l] = max(st[i][j][k][l - 1] , st[i][j + (1 << (l - 1))][k][l - 1]);
					else if(k && !l)
						st[i][j][k][l] = max(st[i][j][k - 1][l], st[i + (1 << (k - 1))][j][k - 1][l]);
					else
						st[i][j][k][l] = max({st[i][j][k - 1][l - 1], st[i][j + (1 << (l - 1))][k - 1][l - 1],
											 st[i + (1 << (k - 1))][j][k - 1][l - 1],
											 st[i + (1 << (k - 1))][j + (1 << (l - 1))][k - 1][l - 1]});
}

int query(int x1, int y1, int x2, int y2){
	int len1 = x2 - x1 + 1, len2 = y2 - y1 + 1;
	int kx = log2(len1), ky = log2(len2);
	int mx1 = max(st[x2 - (1 << kx) + 1][y1][kx][ky], st[x1][y2 - (1 << ky) + 1][kx][ky]);
	int mx2 = max(st[x2 - (1 << kx) + 1][y2 - (1 << ky) + 1][kx][ky], st[x1][y1][kx][ky]);
	return max(mx1, mx2);
}

O(1) 求解LCA

// nlogn预处理 O(1)求lca
const int LOGN = 20;
// 欧拉序的长度要开两倍
PII f[LOGN + 2][2 * N];
int l[N], r[N], depth[N], tot;
// 求欧拉序列,访问的时候加入序列,访问完一个儿子回来的时候加入序列
void dfs(int u, int fa){
    l[u] = ++tot;
    depth[u] = depth[fa] + 1;
    f[0][tot] = {depth[u], u};
    for(int i = h[u]; i != -1; i = ne[i]){
        int j = e[i];
        if(j == fa) continue;
        dfs(j, u);
        f[0][++tot] = {depth[u], u};
    }
    r[u] = tot;
}
// nlogn预处理出来lca
void init(){
    for(int i = 1; i <= LOGN; i ++){
        for(int j = 1; j + (1 << i) - 1 <= tot; j ++){
            f[i][j] = min(f[i - 1][j], f[i - 1][j + (1 << (i - 1))]);
        }
    }
}
// 求L和R的lca
if(L > R) swap(L, R);
int len = __lg(R - L + 1);
int lca = min(f[len][L], f[len][R - (1 << len) + 1]).second;

    res = min(res, min(val[0][v], val[0][u]));
    return res;

}

树状树组

本质功能及扩展

  • 区间查询,查询前缀和复杂度 \(O(logn)\) (可与差分结合)
  • 单点修改,修改数组元素复杂度 \(O(logn)\)
  • 可以以数值出现次数作为维护对象,解决逆序对问题
  • 在差分基础上可从单点查询改为区间和查询,方法是采用补集的思想:
    • 区间内每个元素是差分数组的前缀和,区间和就是两层循环的差分数组元素加和
    • 用补集,补出一个 (x + 1, x) 的矩阵,原来要求的区间和 = 矩阵和((x + 1) * 差分数组前缀和[x]) - i * a[i] 为元素数组的前缀和[x];
  • 处理数组第 \(k\) 小问题(数组为一个排列):
    • 树状树组维护一个元素值为1的数组,代表这个数出现了一次
    • 删除这个数便是,add(x, -1),找到第 k 小采用二分的方法,由于数组只有 0 和 1 ,前缀和有单调性,找到最小的 x,ask(x) = k,就是第 \(k\) 小的数
template<class T>
struct BIT {
    const static int maxn = 1e6 + 10;
	int n;
	T B[maxn];
	inline int lowbit(int x) { return x & (-x); }
    void init(int _n) {
        n = _n;
    }
	void add(int x) {
		for (int i = x; i <= n; i += lowbit(i)) B[i] += 1;
	}
	int ask(int x) {
		int res = 0;
		for (int i = x; i; i -= lowbit(i)) res += B[i];
		return res;
	}
};

\(O(logn)\) 实现树状数组上二分

template<class T>
struct BIT {
    const static int maxn = 1000010;
    T B[maxn];
	int n;
	void init(int _n){
		n = _n;
	}
	inline int lowbit(int x) { return x & (-x); }
	void add(int x, T v) {
		for(int i = x; i <= n; i += lowbit(i)) B[i] += v;
	}
	ll ask(ll s) {		// 查询前缀和小于 s 的最大下标
		int pos = 0;
		for(int j = 18; j >= 0; j--){
			if(pos + (1 << j) <= n && s >= B[pos + (1 << j)]){	// 在bit位上二分,复杂度 O(logn)
				pos += 1 << j;
				s -= B[pos];
			}
		}
		return pos;
	}
};

高维树状数组

template<class T>
struct mult_d_BIT {		// 二维树状数组,k维就开k个循环
	int n, m;
	vector<vector<T>> B;
	mult_d_BIT(int _n, int _m) : n(_n), m(_m), B(_n + 2, vector<T>(_m + 2, 0)) {}
	inline int lowbit(int x) { return x & (-x); }
	void add(int x, int y, T v) {
		for(int i = x; i <= n; i += lowbit(i))		// 循环变量代替x,y
			for(int j = y; j <= m; j += lowbit(j))
				B[i][j] += v;
	}
	T ask(int x, int y) {
		T res = 0;
		for(int i = x; i; i -= lowbit(i))
			for(int j = y; j; j -= lowbit(j))
				res += B[i][j];
		return res;
	}
};

线段树

五个函数: pushup pushdown build query modify
普通线段树:支持区间查询,单点修改
带有 \(lazy\) 标记线段树:同时支持区间修改需要pushdown操作

使用要点及扩展

  • 线段树维护属性根据题意来获取,现有属性不足就补充新的属性直到可以顺利更新为止
  • 同时维护区间加和区间乘的懒标记问题,先乘后加
  • 势能线段树:区间取模、开根,取最值(需要分析复杂度)
  • 维护区间最大公约数问题:
    • 区间加转化为单点修改(维护差分数组)
    • 区间最大公约数 \(gcd(a_1, a_2, ... , a_n) <=> gcd(a_1, a_2 - a_1, a_3 - a_2, ... , a_n - a_{n-1})\) 的最大公约数
    • 其中 \(a_1\) 用差分前缀和求出(需要维护区间和)
  • 区间异或操作每个节点维护区间异或结果的二进制数组,复杂度 \((O(nlogn*32/64))\)

结构体写法

struct SegTree{
    #define ls u << 1
    #define rs u << 1 | 1
    const static int maxn = 1000010;
    struct T{
        int l, r, v;
    }tr[maxn << 2];
    void update(T& rt) {

    }
    void pushup(int u){

    }
    void pushdown(int u){

    }
    void build(int u, int l, int r){        // 建立线段树,(节点编号,节点区间左端点,节点区间右端点)
        tr[u].l = l, tr[u].r = r;
        if(l == r){
            return;
        }
        int mid = (l + r) >> 1;
        build(ls, l, mid), build(rs, mid + 1, r);       
        pushup(u);
    }   
    int query(int u, int l, int r){
        if(tr[u].l >= l && tr[u].r <= r){

            return ;
        }  
        else{
            pushdown(u);        // 递归分裂前pushdown
            int mid = (tr[u].l + tr[u].r) >> 1;
            int res = 0;
            if(l <= mid) res = query(ls, l, r);
            if(r > mid) res += query(rs, l, r);
            return res;
        }
    }
    void modify(int u, int pos, int v){         // 单点修改, (节点编号,查询点下标,更改值)
        if(tr[u].l == pos && tr[u].r == pos){   
           tr[u].v = v;
           return ;
        } 
        pushdown(u);
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(pos <= mid) modify(ls, pos, v);     
        else modify(rs, pos, v);               
        pushup(u);          // 子节点变化,pushup往父节点更新信息
    }   
    void modify(int u, int l, int r, int v){    // 区间修改
        if(tr[u].l >= l && tr[u].r <= r){       // 注意区间修改的递归出口

            return ;
        }
        pushdown(u);        // 递归分裂前 pushdown  
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) modify(ls, l, r, v);
        if(r > mid) modify(rs, l, r, v);
        pushup(u);
    }
};

线段树上二分

模板: 求区间 \([l,r]\) 中大于 \(d\) 的第一个下标, 复杂度 \(O(nlogn)\)

#include<bits/stdc++.h>
typedef long long ll;
#define endl "\n"
using namespace std;
const int N = 2e5 + 10;
ll a[N];

struct SegTree{
    #define ls u << 1
    #define rs u << 1 | 1
    const static int maxn = N;
    struct T{
        int l, r, v;
    }tr[maxn << 2];
    void pushup(int u){
        tr[u].v = max(tr[ls].v, tr[rs].v);
    }
    void build(int u, int l, int r){        // 建立线段树,(节点编号,节点区间左端点,节点区间右端点)
        tr[u].l = l, tr[u].r = r;
        if(l == r){
            tr[u].v = a[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(ls, l, mid), build(rs, mid + 1, r);       
        pushup(u);
    }   
    /* 线段树二分,l,r代表询问区间,与普通查询不同, 
       对于普通查询 l,r 在线段树向下查询的时候不变,向上 pushup 时将信息送给父节点再合并信息。
       对于线段数二分 l,r 查询是搜索答案,不合并子节点信息,递归解决子问题,必要的时候询问区间也要分裂
    */
    int search (int u, int l, int r, int d) {   // 
        if (tr[u].l == l && tr[u].r == r) {
            if (tr[u].v < d) return -1;     // 最大值小于 d,整段区间没有答案
            if (tr[u].l == tr[u].r) return tr[u].l;       // 找到叶子节点就是答案
            int mid = (tr[u].l + tr[u].r) >> 1;     // 中点是线段树中点,其实等价
            if (tr[ls].v >= d) return search(ls, l, mid, d);    // 询问区间也要分开
            return search(rs, mid + 1, r, d);
        }
        int mid = (tr[u].l + tr[u].r) >> 1; // 中点是线段树中点
        if (r <= mid)  return search(ls, l, r, d);      // 询问区间整个在左节点中,询问区间不拆
        else if (l > mid) return search(rs, l, r, d);
        // 询问区间横跨节点区间中点
        int pos = search(ls, l, mid, d);      
        if (pos == -1) pos = search(rs, mid + 1, r, d);
        return pos;
    }
    void modify(int u, int pos, int v){         // 单点修改, (节点编号,查询点下标,更改值)
        if(tr[u].l == pos && tr[u].r == pos){   
           tr[u].v = v;
           return ;
        } 
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(pos <= mid) modify(ls, pos, v);     
        else modify(rs, pos, v);               
        pushup(u);          // 子节点变化,pushup往父节点更新信息
    }   
}tr;


int main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int n, q;
    cin >> n >> q;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    tr.build(1, 1, n);
    while (q--) {
        int op, x, d, l, r;
        cin >> op;
        if (op == 1) {
            cin >> x >> d;
            tr.modify(1, x, d);
        }
        else {
            cin >> l >> r >> d;
            cout << tr.search(1, l, r, d) << endl;
        }
    }
    return 0;
}

动态开点线段树

struct Dynamic_Memory_Tree {
    int node_cnt, root[N];
    struct T {
        int l, r;
        int w, mx;
    }tr[N * 40];
    int get_node() {
        node_cnt++;
        tr[node_cnt].l = tr[node_cnt].r = tr[node_cnt].w = 0;
        return node_cnt;
    }
    void pushup(T& rt) {
        rt.mx = max(tr[rt.l].mx, tr[rt.r].mx);
        rt.w = tr[rt.l].w + tr[rt.r].w;
    }
    void insert(int& q, int l, int r, int w, int pos) {
        if (!q) q = get_node();
        if (l == r) {
            tr[q].w = tr[q].mx = w;
            return ;
        }
        int mid = (l + r) >> 1;
        if (pos <= mid) insert(tr[q].l, l, mid, w, pos);
        else insert(tr[q].r, mid + 1, r, w, pos);
        pushup(tr[q]);
        return;
    }
    void modify(int u, int l, int r, int pos, int w) {
        if (l == r) {
            tr[u].w = tr[u].mx = w;
            return ;
        }
        int mid = (l + r) >> 1;
        if (pos <= mid) modify(tr[u].l, l, mid, pos, w);
        else modify(tr[u].r, mid + 1, r, pos, w);
        pushup(tr[u]);
        return;
    }
    int query_max(int u, int l, int r, int ql, int qr) {
        if (l > qr || r < ql) return 0;
        if (ql <= l && r <= qr) {
            return tr[u].mx;
        }
        int mid = (l + r) >> 1, res = 0;
        if (ql <= mid) res = query_max(tr[u].l, l, mid, ql, qr);
        if (qr > mid) res = max(res, query_max(tr[u].r, mid + 1, r, ql, qr));
        return res;
    }
    int query_sum(int u, int l, int r, int ql, int qr) {
        if (l > qr || r < ql) return 0;
        if (ql <= l && r <= qr) {
            return tr[u].w;
        }
        int mid = (l + r) >> 1, res = 0;
        if (ql <= mid) res = query_sum(tr[u].l, l, mid, ql, qr);
        if (qr > mid) res = res + query_sum(tr[u].r, mid + 1, r, ql, qr);
        return res;
    }
}dtr;

二维数点,扫描线

二维数点

#include<bits/stdc++.h>
typedef long long ll;
#define arr(x) (x).begin(),(x).end()
#define pb push_back
#define endl "\n"
using namespace std;

/* 二维数点: 1. 二维前缀和 + 容斥破除多重限制
            2. 离线询问,排序其中一维,将查询作为虚点,用数据结构维护
            3. 相同坐标,插入要比查询优先。
            4. 树状数组下标不为 0 */

template<class T>
struct BIT {
    const static int maxn = 2e5 + 10;
	int n;
	T B[maxn];
	inline int lowbit(int x) { return x & (-x); }
    void init(int _n) {
        n = _n;
    }
	void add(int x, int v) {
		for (int i = x; i <= n; i += lowbit(i)) B[i] += v;
	}
	T ask(int x) {
		T res = 0;
		for (int i = x; i; i -= lowbit(i)) res += B[i];
		return res;
	}
};

BIT<int> bt;

vector<int> alls;

int find(int x) {
    return upper_bound(arr(alls), x) - alls.begin();
}

struct P {
    int id, x, y;
    bool operator < (const P& p) const {
        if (y != p.y) return y < p.y;
        if (x != p.x) return x < p.x;
        return id < p.id;       
    }
};

int main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int n, q;
    cin >> n >> q;
    vector<P> v;
    alls.pb(0);
    for (int i = 0; i < n; i++) {
        int x, y;
        cin >> x >> y;
        v.pb({-1, x, y});
        alls.pb(x);
    }
    sort(arr(alls));
    alls.erase(unique(arr(alls)), alls.end());
    bt.init(alls.size());
    for (int i = 0; i < 4 * q; i += 4) {
        int x1, x2, y1, y2;
        cin >> x1 >> x2 >> y1 >> y2;
        v.pb({i, x2, y2}), v.pb({i + 1, x1 - 1, y1 - 1});
        v.pb({i + 2, x2, y1 - 1}), v.pb({i + 3, x1 - 1, y2});
    }
    sort(arr(v));
    vector<array<int, 4>> ans(q);
    for (auto & [id, x, y] : v) {
        if (id != -1) 
            ans[id / 4][id & 3] = bt.ask(find(x));
        else 
            bt.add(find(x), 1);
    }
    for (const auto & res: ans)  
        cout << res[0] + res[1] - res[2] - res[3] << endl;
    return 0;
}

扫描线

区间内不同数个数/和

#include<bits/stdc++.h>
typedef long long ll;
typedef unsigned long long ull;
typedef std::pair<int, int> PII;
typedef std::pair<ll, ll> PLL;
typedef double db;
#define arr(x) (x).begin(),(x).end()
#define x first
#define y second
#define pb push_back
#define endl "\n"
using namespace std;
/*----------------------------------------------------------------------------------------------------*/
/* 二维数点做法
    1. pos[i], a[i] 上一次出现的位置
    2. pos[i] < l, l <= i <= r
   扫描线做法
    1. 同样记录 pos[i],但有一个特殊的 ans[l] 数组。
    2. 我们遍历 r : 1->n, ans[l] 表示 [l,r] 的答案。
    3. r 从小到大。遇到 a[r],实际上对 ans[pos[a[r]] + 1, a[r]] += a[r], 在 [pos[a[r] + 1 , r] 这段区间 a[r] 没有出现过
    4. 然后询问离线,用树状数组维护差分,区间加,单点查,计算答案即可。
*/

template<class T>
struct BIT {
    const static int maxn = 2e5 + 10;
	int n;
	T B[maxn];
	inline int lowbit(int x) { return x & (-x); }
    void init(int _n) {
        n = _n;
    }
	void add(int x, int v) {
		for (int i = x; i <= n; i += lowbit(i)) B[i] += v;
	}
	T ask(int x) {
		T res = 0;
		for (int i = x; i; i -= lowbit(i)) res += B[i];
		return res;
	}
};
BIT<ll> bt;

int main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int n, q;
    cin >> n >> q;
    vector<int> a(n + 1, 0);
    for (int i = 1; i <= n; i++) 
        cin >> a[i];
    vector<vector<PII>> v(n + 1);
    vector<ll> pos(n + 1, 0), ans(n, 0);
    for (int i = 0; i < q; i++) {
        int l, r;
        cin >> l >> r;
        v[r].pb({l, i});
    }
    bt.init(n);
    for (int r = 1; r <= n; r ++) {
        int L = pos[a[r]];
        bt.add(L + 1, a[r]);
        bt.add(r + 1, -a[r]);
        pos[a[r]] = r;
        for (auto t: v[r])
            ans[t.second] = bt.ask(t.first);
    }
    for (auto t: ans)
        cout << t << endl;
    return 0;
}

区间mex

#include<bits/stdc++.h>
typedef long long ll;
typedef unsigned long long ull;
typedef std::pair<int, int> PII;
typedef std::pair<ll, ll> PLL;
typedef double db;
#define arr(x) (x).begin(),(x).end()
#define x first
#define y second
#define pb push_back
#define endl "\n"
using namespace std;
/*----------------------------------------------------------------------------------------------------*/

/* pos[x], x 最后一次出现的位置, 
   每次查询则求,最小的 x 满足 pos[x] < l, 实际是一个线段树二分问题。
   数值过大是没有意义的。 
*/

struct SegTree{
    #define ls u << 1
    #define rs u << 1 | 1
    const static int maxn = 2e5 + 10;
    struct T{
        int l, r, v;
    }tr[maxn << 2];
    void pushup(int u){
        tr[u].v = min(tr[ls].v, tr[rs].v);
    }
    void build(int u, int l, int r){        // 建立线段树,(节点编号,节点区间左端点,节点区间右端点)
        tr[u].l = l, tr[u].r = r;
        if(l == r){
            return;
        }
        int mid = (l + r) >> 1;
        build(ls, l, mid), build(rs, mid + 1, r);       
        pushup(u);
    }   
    /* 线段树二分,l,r代表询问区间,与普通查询不同, 
       对于普通查询 l,r 在线段树向下查询的时候不变,向上 pushup 时将信息送给父节点再合并信息。
       对于线段数二分 l,r 查询是搜索答案,不合并子节点信息,递归解决子问题,必要的时候询问区间也要分裂
    */
    int search (int u, int d) {   // 每次都是对整个 [1,r] 区间查,没有询问区间
        if (tr[u].l == tr[u].r) return tr[u].l;       // 找到叶子节点就是答案
        int mid = (tr[u].l + tr[u].r) >> 1;     // 中点是线段树中点,其实等价
        if (tr[ls].v < d) return search(ls, d);    // 询问区间也要分开
        return search(rs, d);
    }
    void modify(int u, int pos, int v){         // 单点修改, (节点编号,查询点下标,更改值)
        if(tr[u].l == pos && tr[u].r == pos){   
           tr[u].v = v;
           return ;
        } 
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(pos <= mid) modify(ls, pos, v);     
        else modify(rs, pos, v);               
        pushup(u);          // 子节点变化,pushup往父节点更新信息
    }   
}tr;

int main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int n, q;
    cin >> n >> q;
    vector<int> a(n + 1, 0), ans(n, 0);
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        a[i] = min(a[i], n + 1);
    }
    vector<vector<PII>> v(n + 1);
    tr.build(1, 0, n + 1);
    for (int i = 0; i < q; i++) {
        int l, r;
        cin >> l >> r;
        v[r].pb({l, i});
    }
    for (int r = 1; r <= n; r++) {
        tr.modify(1, a[r], r);
        for (auto t: v[r]) 
            ans[t.second] = tr.search(1, t.first);
    }
    for (auto t: ans)
        cout << t << endl;
    return 0;
}

求矩形面积并

#include<bits/stdc++.h>
typedef long long ll;
#define arr(x) (x).begin(),(x).end()
#define pb push_back
#define endl "\n"
using namespace std;
/*----------------------------------------------------------------------------------------------------*/
/* cnt: 当前位置被覆盖了多少次,要数 cnt=0 其实不好维护
   退而求其次维护 -> 最小值、 最小值出现次数(对应线段长度总和) */

vector<int> alls;
vector<array<int, 4>> event;

int find (int x) {
    return lower_bound(arr(alls), x) - alls.begin();
}

struct SegTree{
    #define ls u << 1
    #define rs u << 1 | 1
    const static int maxn = 2e5 + 10;
    struct T{
        int l, r, v, mincnt, add;
    }tr[maxn << 3];
    void update(T& rt, int add) {
        rt.add += add;
        rt.v += add;
    }
    void pushup(int u){
        tr[u].v = min(tr[ls].v, tr[rs].v);
        if (tr[ls].v == tr[rs].v)
            tr[u].mincnt = tr[ls].mincnt + tr[rs].mincnt;
        else tr[ls].v < tr[rs].v ? tr[u].mincnt = tr[ls].mincnt : tr[u].mincnt = tr[rs].mincnt;
    }
    void pushdown(int u){
        if (tr[u].add) {
            update(tr[ls], tr[u].add);
            update(tr[rs], tr[u].add);
            tr[u].add = 0;
        }
    }
    void build(int u, int l, int r){        // 建立线段树,(节点编号,节点区间左端点,节点区间右端点)
        tr[u].l = l, tr[u].r = r;
        if(l == r){
            tr[u].v = tr[u].add = 0, tr[u].mincnt = alls[l] - alls[l - 1];
            return;
        }
        int mid = (l + r) >> 1;
        build(ls, l, mid), build(rs, mid + 1, r);       
        pushup(u);
    }   
    void modify(int u, int l, int r, int v){    // 区间修改
        if(tr[u].l >= l && tr[u].r <= r){       // 注意区间修改的递归出口
            update(tr[u], v);
            return ;
        }
        pushdown(u);        // 递归分裂前 pushdown  
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) modify(ls, l, r, v);
        if(r > mid) modify(rs, l, r, v);
        pushup(u);
    }
}tr;

int main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int n;
    cin >> n;
    for (int i = 0; i < n; i++) {
        int x1, x2, y1, y2;
        cin >> x1 >> x2 >> y1 >> y2;
        alls.pb(x1), alls.pb(x2);
        event.pb({y1, x1, x2, 1});
        event.pb({y2, x1, x2, -1});
    }
    sort(arr(alls));
    alls.erase(unique(arr(alls)), alls.end());
    int m = alls.size() - 1;        // m 端点 m - 1 段
    tr.build(1, 1, m);
    int prey = 0;
    ll ans = 0;
    ll tot = tr.tr[1].mincnt;
    sort(arr(event));
    for (auto evt: event) {
        auto [y, xx1, xx2, t] = evt;
        ll len = tot;
        if (tr.tr[1].v == 0)
            len = tot - tr.tr[1].mincnt;
        ans += (evt[0] - prey) * len;
        prey = evt[0];
        int x1 = find(evt[1]) + 1, x2 = find(evt[2]);   // 数组下标(0.m-1) 映射到(1,m-1)表示线段的话,左端点需要加1
        tr.modify(1, x1, x2, evt[3]);
    }
    cout << ans << endl;
    return 0;
}

求矩形面积并(不回收标记)

#include<bits/stdc++.h>
#define ls u << 1
#define rs u << 1 | 1
#define pb push_back
typedef long long ll;
using namespace std;
const int N = 1e4 + 10;
int n;
vector<double> ys;

// 多复习扫描线的做法,扩展性不强

struct Seg{
    double x, y1, y2;
    int k;
    bool operator < (const Seg& s) const{
        return x < s.x;
    }
}seg[N << 1];

struct T{
    int l, r, cnt;
    double len;
}tr[N << 3];        // 线段树空间是线段树维护的序列的 4 倍

int find(double y){
    return lower_bound(ys.begin(), ys.end(), y) - ys.begin();
}

void pushup(int u){
    if(tr[u].cnt) tr[u].len = ys[tr[u].r + 1] - ys[tr[u].l];    // 节点有被覆盖,ys 内元素代表区间, 要表达 tr[u].r 所想要的元素要 + 1
    else if(tr[u].l != tr[u].r)     // 节点未被覆盖,但是不是叶子节点可以从子节点求出len
        tr[u].len = tr[ls].len + tr[rs].len;
    else
        tr[u].len = 0;      // 节点未被覆盖, 有可能是被更新cnt = 0, 需要将节点的 len 置为 0
}

void build(int u, int l, int r){
    if(l == r)
        tr[u] = (T){l, r, 0, 0};
    else{
        tr[u] = (T){l, r, 0 ,0};
        int mid = (tr[u].l + tr[u].r) >> 1;
        build(ls, l, mid), build(rs, mid + 1, r);
    }
}

void modify(int u, int l, int r, int k){        // 扫描线不需要pushdown
    if(tr[u].l >= l && tr[u].r <= r){
        tr[u].cnt += k;
        pushup(u);      // 该节点有更新,需要 pushup 更新父节点
    }
    else{
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(l <= mid) modify(ls, l, r, k);
        if(r > mid) modify(rs, l, r, k);
        pushup(u);
    }
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    int T = 0;
    while(cin >> n, n){
        ys.clear();
        for(int i = 0, j = 0; i < n; i++){
            double x1, y1, x2, y2;
            cin >> x1 >> y1 >> x2 >> y2;
            seg[j++] = {x1, y1, y2, 1};
            seg[j++] = {x2, y1, y2, -1};
            ys.pb(y1), ys.pb(y2);       
        }
        sort(ys.begin(), ys.end());
        ys.erase(unique(ys.begin(), ys.end()), ys.end());
        sort(seg, seg + 2 * n);
        build(1, 0, ys.size() - 2);     // 下标为长度 -1, 记录元素表示区间还要-1
        double res = 0<F8>;
        for(int i = 0; i < 2 * n; i++){
            if(i > 0)
                res += tr[1].len * (seg[i].x - seg[i - 1].x);
            modify(1, find(seg[i].y1), find(seg[i].y2) - 1, seg[i].k);      // 单个元素代表以该下标值为左端点的节点, 转换到线段树维护的区间右端点要 -1
        }
        cout << "Test case #" << ++T << endl;
        cout << "Total explored area: " << fixed << setprecision(2) << res << endl << endl;         
    }
    return 0;
}

主席树(可持久化值域线段树)

const int N = 2e5 + 10;
struct Chairman_Tree {
    int node_cnt, root[N];
    struct T {
        int l, r;
        int cnt;
    }tr[N * 20];
    int get_node() {
        node_cnt++;
        tr[node_cnt].l = tr[node_cnt].r = tr[node_cnt].cnt = 0;
        return node_cnt;
    }
    void pushup(T& rt) {
        rt.cnt = tr[rt.l].cnt + tr[rt.r].cnt;
    }
    void insert(int p, int& q, int l, int r, int val) {
        q = get_node();
        tr[q] = tr[p];
        if (l == r) {
            tr[q].cnt++;
            return ;
        }
        int mid = (l + r) >> 1;
        if (val <= mid) insert(tr[p].l, tr[q].l, l, mid, val);
        else insert(tr[p].r, tr[q].r, mid + 1, r, val);
        pushup(tr[q]);
    }
    int query(int p, int q, int l, int r, int k) { // query_kmin
        if (l == r) return l;
        int mid = (l + r) >> 1;
        int lsz = tr[tr[q].l].cnt - tr[tr[p].l].cnt;
        if (lsz >= k) return query(tr[p].l, tr[q].l, l, mid, k);
        return query(tr[p].r, tr[q].r, mid + 1, r, k - lsz);
    }
}ctr;

带修主席树

// Luogu P2617
#include<bits/stdc++.h>
typedef long long ll;
typedef unsigned long long ull;
typedef std::pair<int, int> PII;
typedef std::pair<ll, ll> PLL;
typedef double db;
#define arr(x) (x).begin(),(x).end()
#define get_sz(v) (int)v.size()
#define fi first
#define se second
#define pb push_back
#define endl "\n"
using namespace std;
template <typename T> std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) { out << "["; bool first = true; for (auto &&e : v) { if (first) { first = false;} else {out << ", ";} out << e; } return out << "]"; }
template <typename A, typename B> std::ostream &operator<<(std::ostream &out, const std::pair<A, B> &v) {  return out << "(" << v.first << ", " << v.second << ")"; }
template <typename K> std::ostream &operato<<(std::ostream &out, const std::set<K> &s) {  out << "{"; bool first = true; for (auto &&k : s) { if (first) { first = false; } else { out << ", "; } out << k; } return out << "}"; }
template <typename K, typename V> std::ostream &operator<<(std::ostream &out, const std::map<K, V> &m) { out << "{"; bool first = true; for (auto &&[k, v] : m) { if (first) { first = false; } else { out << ", "; } out << k << ": " << v; } return out << "}"; }
template <class T> vector<vector<T>> Vector(int n, int m) { return vector<vector<T>> (n, vector<T> (m, 0)); }
template <class T> vector<vector<vector<int>>> Vector(int i, int j, int k) { return vector<vector<vector<T>>> (i, vector<vector<T>>(j, vector<T>(k, 0))); }
/*----------------------------------------------------------------------------------------------------*/
const int N = 1e6 + 10;


int cntl, cntr, tmpl[N], tmpr[N];
int a[N], n, m, INF;

struct Chairman_Tree {
    int node_cnt, root[N];
    struct T {
        int l, r;
        int cnt;
    }tr[N * 20];
    int get_node() {
        node_cnt++;
        tr[node_cnt].l = tr[node_cnt].r = tr[node_cnt].cnt = 0;
        return node_cnt;
    }
    void pushup(T& rt) {
        rt.cnt = tr[rt.l].cnt + tr[rt.r].cnt;
    }
    void insert(int p, int& q, int l, int r, int val) {
        q = get_node();
        tr[q] = tr[p];
        if (l == r) {
            tr[q].cnt++;
            return ;
        }
        int mid = (l + r) >> 1;
        if (val <= mid) insert(tr[p].l, tr[q].l, l, mid, val);
        else insert(tr[p].r, tr[q].r, mid + 1, r, val);
        pushup(tr[q]);
    }
    void modify(int& u, int l, int r, int pos, int v) {
        if (!u) u = get_node();
        if (l == r)  {
            tr[u].cnt += v;
            return ;
        }
        int mid = (l + r) >> 1;
        if (pos <= mid) modify(tr[u].l, l, mid, pos, v);
        else modify(tr[u].r, mid + 1, r, pos, v);
        pushup(tr[u]);
    }
    int query(int l, int r, int k) { // query_kmin
        if (l == r) {
            return l;
        }
        int mid = (l + r) >> 1;
        ll sum = 0;
        for (int i = 1; i <= cntr; i++) sum += tr[tr[tmpr[i]].l].cnt;
        for (int i = 1; i <= cntl; i++) sum -= tr[tr[tmpl[i]].l].cnt;
        if (sum >= k) {
            for (int i = 1; i <= cntr; i++) tmpr[i] = tr[tmpr[i]].l;
            for (int i = 1; i <= cntl; i++) tmpl[i] = tr[tmpl[i]].l;
            return query(l, mid, k);
        }
        for (int i = 1; i <= cntr; i++) tmpr[i] = tr[tmpr[i]].r;
        for (int i = 1; i <= cntl; i++) tmpl[i] = tr[tmpl[i]].r;
        return query(mid + 1, r, k - sum);
    }
}ctr;


template<class T>
struct BIT {
    const static int maxn = 2e5 + 10;
	int n;
	inline int lowbit(int x) { return x & (-x); }
    void init(int _n) {
        n = _n;
    }
	void add(int x, int v) {
		for (int i = x; i <= n; i += lowbit(i)) ctr.modify(ctr.root[i], 0, INF, a[x], v);
	}
};
BIT<int> bt;


int Query(int ql, int qr, int k) {
    cntl = cntr = 0;
    for (int i = qr; i; i -= (i & -i)) tmpr[++cntr] = ctr.root[i];
    for (int i = ql - 1; i; i -= (i & -i)) tmpl[++cntl] = ctr.root[i];
    return ctr.query(0, INF, k);
}

vector<array<int, 4>> q;
vector<int> alls;

int find(int x) {
    return lower_bound(arr(alls), x) - alls.begin();
}

int main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        alls.pb(a[i]);
    }
    bt.init(n);
    for (int i = 0; i < m; i++) {
        string s;
        int l, r, k, x, y;
        cin >> s;
        int t;
        if (s == "Q") {
            t = 0;
            cin >> l >> r >> k;
            q.pb({t, l, r, k});
        }
        else {
            t = 1;
            cin >> x >> y;
            q.pb({t, x, y, 0});
            alls.pb(y);
        }
    }
    sort(arr(alls));
    alls.erase(unique(arr(alls)), alls.end());
    INF = alls.size();
    for (int i = 1; i <= n; i++) {
        a[i] = find(a[i]);
        bt.add(i, 1);
    }
    for (int i = 0; i < m; i++) {
        if (!q[i][0]) {
            auto [_, l, r, k] = q[i];
            cout << alls[Query(l, r, k)] << endl;
        }
        else {
            auto [_, x, y, __] = q[i];
            bt.add(x, -1);
            a[x] = find(y);
            bt.add(x, 1);
        }
    }
    return 0;
}

树套树

线段树套平衡树

原理: 线段树维护整个区间,平衡树维护线段树节点区间内的有序序列

  • 仅涉及单点修改、区间前驱后继时,可以在线段树节点内设立 multiset 来替代平衡树

例题:树套树

  1. l r x,查询整数 \(x\) 在区间 \([l,r]\) 内的排名。
  2. l r k,查询区间 \([l,r]\) 内排名为 \(k\) 的值。
  3. pos x,将 \(pos\) 位置的数修改为 \(x\)
  4. l r x,查询整数 \(x\) 在区间 \([l,r]\) 内的前驱(前驱定义为小于 \(x\),且最大的数)。
  5. l r x, 查询整数 \(x\) 在区间 \([l,r]\) 内的后继(后继定义为大于 \(x\),且最小的数)。
#define ls u << 1
#define rs u << 1 | 1
using namespace std;
const int N = 5e4 + 10, M = 1500010, INF = 1e9;     
// 空间计算:线段树节点4*n,每个节点固定2个哨兵等于 4 * n * 2,总共logn层每层长度为n
// 平衡树节点n * logn,最后空间 8 * n + logn * n <=> 1200000,开大点就1500010

struct Splay{
    int v, p, s[2];
    int size;
    void init(int v_, int p_){
        v = v_, p = p_;
        size = 1;
    }
}spl[M];

int idx, w[N], n, m;

void pushup(int u){
    spl[u].size = spl[spl[u].s[0]].size + spl[spl[u].s[1]].size + 1;
}

void rotate(int x){
    int y = spl[x].p, z = spl[y].p;
    int k = spl[y].s[1] == x;
    spl[z].s[spl[z].s[1] == y] = x, spl[x].p = z;
    spl[y].s[k] = spl[x].s[k ^ 1], spl[spl[x].s[k ^ 1]].p = y;
    spl[x].s[k ^ 1] = y, spl[y].p = x;
    pushup(y), pushup(x);
}

void splay(int x, int k, int& root){
    while(spl[x].p != k){
        int y = spl[x].p, z = spl[y].p;
        if(z != k){
            if(spl[y].s[0] == x != spl[z].s[0] == y) rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
    if(!k) root = x;
}

int get_less(int v, int root){        // 找到比数值v小的数有多少个
    int u = root, res = 0;
    while(u){
        if(v > spl[u].v) res += spl[spl[u].s[0]].size + 1, u = spl[u].s[1];
        else u = spl[u].s[0];
    }
    return res;
}

int get_pre(int v, int& root){      // 获取v的前缀
    int u = root, res = -INF;
    while(u){
        if(v > spl[u].v) res = max(res, spl[u].v), u = spl[u].s[1];
        else u = spl[u].s[0];
    }
    return res;
}

int get_next(int v, int root){     // 获取v的后缀
    int u = root, res = INF;
    while(u){
        if(v < spl[u].v) res = min(res, spl[u].v), u = spl[u].s[0];
        else u = spl[u].s[1];
    }
    return res;
}

void insert(int v, int& root){
    int u = root, p = 0;
    while(u) p = u, u = spl[u].s[v > spl[u].v];
    u = ++idx;
    if(p) spl[p].s[v > spl[p].v] = u;
    spl[u].init(v, p);
    splay(u, 0, root);
}

struct T{
    int l, r, u;
}tr[N << 2];

void update(int x, int y, int& root){       // 在splay中找到值为x的点并更新为y
    int u = root;
    while(u){
        if(spl[u].v == x) break;
        else if(spl[u].v < x) u = spl[u].s[1];
        else u = spl[u].s[0];
    }
    splay(u, 0, root);          // 节点x旋转到根
    int l = spl[u].s[0], r = spl[u].s[1];
    while(spl[l].s[1]) l = spl[l].s[1];
    while(spl[r].s[0]) r = spl[r].s[0];
    splay(l, 0, root), splay(r, l, root);       // 找到x在序列中前一个节点和后一个节点,并splay到根
    spl[r].s[0] = 0;        // 删除节点x
    pushup(r), pushup(l);   // 删除(修改)后要pushup(r), pushup(l)
    insert(y, root);        // 将y插入,insert会嵌套splay嵌套rotate带有pushup
}

void change(int u, int pos, int x){
    update(w[pos], x, tr[u].u);     // 只要线段树节点维护的区间包含pos就需要update
    if(tr[u].l == tr[u].r) return;
    int mid = (tr[u].l + tr[u].r) >> 1;
    if(pos <= mid) change(ls, pos, x);
    else change(rs, pos, x);
}

int query(int u, int l, int r, int x){     // 查询x的排名
    if(tr[u].l >= l && tr[u].r <= r){
        return get_less(x, tr[u].u) - 1;    // 查找的是小于x的数有多少个,因为包含哨兵需要-1
    }
    int mid = (tr[u].l + tr[u].r) >> 1, res = 0;
    if(l <= mid) res += query(ls, l, r, x);     // 整个区间小于x的数=左右区间小于x的数之和
    if(r > mid) res += query(rs, l, r, x);
    return res;
}

int query_pre(int u, int l, int r, int x){
    if(tr[u].l >= l && tr[u].r <= r)
        return get_pre(x, tr[u].u);
    int mid = (tr[u].l + tr[u].r) >> 1, res = -INF;
    if(l <= mid) res = max(res, query_pre(ls, l, r, x));    // 在区间中小于x的最大数是,左右区间小于x的数的max
    if(r > mid) res = max(res, query_pre(rs, l, r, x));
    return res;
}

int query_next(int u, int l, int r, int x){
    if(tr[u].l >= l && tr[u].r <= r)
        return get_next(x, tr[u].u);
    int mid = (tr[u].l + tr[u].r) >> 1, res = INF;
    if(l <= mid) res = min(res, query_next(ls, l, r, x));       // 与上同理
    if(r > mid) res = min(res, query_next(rs, l, r, x));
    return res;
}

void build(int u, int l, int r){
    tr[u].l = l, tr[u].r = r;           // tr[u].u会随着函数引用而改变
    insert(-INF, tr[u].u), insert(INF, tr[u].u);        // 每个平衡树都要加入哨兵
    for(int i = l; i <= r; i++) insert(w[i], tr[u].u);
    if(tr[u].l == tr[u].r) return;
    int mid = (tr[u].l + tr[u].r) >> 1;
    build(ls, l, mid), build(rs, mid + 1, r);
}

int main(){
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> w[i];
    build(1, 1, n);
    for(int i = 1; i <= m; i++){
        int op, l, r, x, pos, k;
        cin >> op;
        if(op == 1){    // x 在[l,r]中排名
            cin >> l >> r >> x;
            cout << query(1, l, r, x) + 1 << endl;      // query查询小于x的个数,成为x的排名需要+1
        }
        else if(op == 2){       // [l,r]中排名k的值
            cin >> l >> r >> k;
            int a = 0, b = 1e8;
            while(a < b){               // 由于排名与值具有单调性,可以二分大小来找到排名为k的数
                int mid = (a + b + 1) >> 1;
                if(query(1, l, r, mid) + 1 <= k) a = mid;   // query查询小于x的个数,成为x的排名需要+1
                else b = mid - 1;
            }
            cout << b << endl;
        }
        else if(op == 3){
            cin >> pos >> x;
            change(1, pos, x);
            w[pos] = x;
        }
        else if(op == 4){
            cin >> l >> r >> x;
            cout << query_pre(1, l, r, x) << endl;
        }
        else{
            cin >> l >> r >> x;
            cout << query_next(1, l, r, x) << endl;
        }
    }
    return 0;
}

线段树套线段树

例题:查询第 \(K\) 大数

  1. 离线处理存储所有要插入的数
  2. 外层权值线段树维护权值,内层普通线段树维护权值范围内下标区间,支持区间加和求和
  3. 区间加使用标记永久化优化常数
#define ls u << 1
#define rs u << 1 | 1
#define pb push_back
using namespace std;
// 离线处理 + 权值线段树套动态开点线段树维护区间信息 + 标记永久化
// 也可以用树状数组+线段树+整体二分求解
const int N = 5e4 + 10, M = 1e7 + 3e6 + 10;

struct Q{
    int op, l, r, x;
}q[N];

struct T1{
    int l, r, u;
}tr1[N << 2];

int n, m, idx;
vector<int> alls;

struct T2{
    int l, r, add, sum;     // l, r 是左右儿子
}tr2[M];

int inter(int l, int r, int ql, int qr){
    return min(qr, r) - max(l, ql) + 1;
}

void update(int u, int l, int r, int ql, int qr){       // l, r 是tr2的区间左右端点
    tr2[u].sum += inter(l, r, ql, qr);
    if(l >= ql && r <= qr){
        tr2[u].add ++;
        return ;
    }
    int mid = (l + r) >> 1;
    if(ql <= mid){
        if(!tr2[u].l) tr2[u].l = ++ idx;
        update(tr2[u].l, l, mid, ql, qr);
    }
    if(qr > mid){
        if(!tr2[u].r) tr2[u].r = ++ idx;
        update(tr2[u].r, mid + 1, r, ql, qr);
    }
}

int get_sum(int u, int l, int r, int ql, int qr, int add){
    int res = 0;
    if(l >= ql && r <= qr)
        return tr2[u].sum + (r - l + 1) * add;
    add += tr2[u].add;          // 节点本身标记不变,传递的标记要累加
    int mid = (l + r) >> 1;
    if(ql <= mid){
        if(!tr2[u].l) res += inter(l, mid, ql, qr) * add;
        else res += get_sum(tr2[u].l, l, mid, ql, qr, add);
    }
    if(qr > mid){
        if(!tr2[u].r) res += inter(mid + 1, r, ql, qr) * add;
        else res += get_sum(tr2[u].r, mid + 1, r, ql, qr, add);
    }
    return res;
}

int get(int x){
    return lower_bound(alls.begin(), alls.end(), x) - alls.begin();
}

int query(int u, int l, int r, int x){
    if(tr1[u].l == tr1[u].r) return tr1[u].l;
    int mid = (tr1[u].l + tr1[u].r) >> 1;
    int k = get_sum(tr1[rs].u, 1, n, l, r, 0);      // 获取右边子树有多少个数
    if(k >= x) return query(rs, l, r, x);
    else return query(ls, l, r, x - k);
}

void change(int u, int l, int r, int x){
    update(tr1[u].u, 1, n, l, r);
    if(tr1[u].l == tr1[u].r) return ;
    int mid = (tr1[u].l + tr1[u].r) >> 1;
    if(x <= mid) change(ls, l, r, x);
    else change(rs, l, r, x);
}

void build(int u, int l, int r){
    tr1[u].l = l, tr1[u].r = r, tr1[u].u = ++ idx;
    if(tr1[u].l == tr1[u].r) return;
    int mid = (tr1[u].l + tr1[u].r) >> 1;
    build(ls, l, mid), build(rs, mid + 1, r);
}

int main(){
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    cin >> n >> m;
    for(int i = 1; i <= m; i++){
        cin >> q[i].op >> q[i].l >> q[i].r >> q[i].x;
        if(q[i].op == 1)
            alls.pb(q[i].x);
    }
    sort(alls.begin(), alls.end());
    alls.erase(unique(alls.begin(), alls.end()), alls.end());
    build(1, 0, alls.size() - 1);
    for(int i = 1; i <= m; i++){
        if(q[i].op == 1) change(1, q[i].l, q[i].r, get(q[i].x));
        else cout << alls[query(1, q[i].l, q[i].r, q[i].x)] << endl;
    }
    return 0;
}

莫队算法

基础莫队

已知可以维护的区间信息:

  1. 各个数字出现次数之和、平方和、立方和、等等。
  2. 使用 bitset 保存区间内是否有两数 和 或 差 为 \(x\) ,区间内乘积可采取暴力枚举因素,关于商出题人说可做。
  3. 区间众数的出现次数(要开一个nums数组,nums[i] 记录出现 \(i\) 次的数字有多少个,也可以用回滚莫队实现
  4. 区间内子区间异或值等于 \(k\) 的个数。(采用异或前缀和,注意询问的 \(l\) 需要减 \(1\) ,达到前缀和两点相减)

例题 : HH的项链

维护区间内不同数个数

#include<bits/stdc++.h>
typedef long long ll;
#define endl "\n"
using namespace std;
const int N = 50010, M = 2e5 + 10, S = 1e6 + 10;

int cnt[S], ans[M], n, m, len, w[N], res;

#define get(x) (x / len)

struct Q{
    int l, r, id;
    bool operator < (const Q & q)const{     // 第一关键字是块号,第二关键字是右端点大小
        if(get(l) == get(q.l))  return r < q.r;     
        return get(l) < get(q.l);
    }
}q[M];

bool cmp(const Q& a, const Q& b){       // 按照奇偶性排序,玄学优化可能快一倍
    return get(a.l) ^ get(b.l) ? get(a.l) < get(b.l) : ((get(a.l) & 1) ? a.r < b.r : a.r > b.r);
}

void add(int x){
    if(!cnt[x]) res++;
    cnt[x] ++;
}

void del(int x){
    cnt[x] --;
    if(!cnt[x]) res--;
}


int main() {
    scanf("%d", &n);
    for(int i = 1; i <= n; i++)
        scanf("%d", &w[i]);
    scanf("%d", &m);
    len = max(1, (int)sqrt((double) n *n / m));    // m 过大尝试 块长度为 sqrt(n * n / m)
    for(int i = 1; i <= m; i++){
        int l, r;
        scanf("%d%d", &l, &r);
        q[i] = {l, r, i};
    }
    sort(q + 1, q + 1 + m, cmp);
    for(int i = 1, l = 1, r = 0; i <= m; i++){
        auto [ql, qr, id] = q[i];
        while(l > ql) add(w[--l]);      // 加答案++ x
        while(r < qr) add(w[++r]);
        while(l < ql) del(w[l++]);      // 减答案x ++
        while(r > qr) del(w[r--]);
        ans[id] = res;
    }
    for(int i = 1; i <= m; i++)
        printf("%d\n", ans[i]);
    return 0;
} 

带修莫队

在基础莫队的基础上支持单点修改

核心:加入时间戳这一维度

例题 : 数颜色(\(n\)\(m\) 同阶取块大小为 \(n^{\frac{2}{3}}\),复杂度为 \(n^{\frac{5}{3}}\)

#include<bits/stdc++.h>
typedef long long ll;
typedef std::pair<int, int> PII;
typedef std::pair<ll, ll> PLL;
//#pragma GCC optimize(3,"Ofast","inline")
#define x first
#define y second
#define pb push_back
#define mkp make_pair
#define endl "\n"
using namespace std;
const int N = 10010, S = 1e6 + 10;

// 带修莫队,将操作时刻进行时间戳,多加一维。 复杂度O(n^{4/3} * t{1/3}), 区间长度为 (n*t)^{1/3};

int cnt[S], ans[N], w[N], cntc, cntq, n, m, len, res;

struct Q{
    int l, r, t, id;
}q[N];

struct C{
    int p, v;
}c[N];

int get(int x){
    return x / len;
}

bool cmp(const Q& a, const Q& b){
    int al = get(a.l), ar = get(a.r);
    int bl = get(b.l), br = get(b.r);
    if(al != bl) return al < bl;
    if(ar != br) return ar < br;
    return a.t < b.t;
}

void add(int x){
    if(!cnt[x]) res ++;
    cnt[x] ++;
}

void del(int x){
    cnt[x] --;
    if(!cnt[x]) res--;
}

int main(){
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m;
    for(int i = 1; i <= n; i++)
        cin >> w[i];
    for(int i = 1; i <= m; i++){
        char op;
        int l, r;
        cin >> op >> l >> r;
        if(op == 'Q') cntq++, q[cntq] = {l, r, cntc, cntq};
        else c[++cntc] = {l, r};
    }
    len = cbrt((double)n * max(1 , cntq)) + 1;
    sort(q + 1, q + 1 + cntq, cmp);
    for(int i = 1, l = 1, r = 0, now = 0; i <= cntq; i++){
        auto [ql, qr, t, id] = q[i];
        // 先处理前两维
        while(l < ql) del(w[l++]);
        while(l > ql) add(w[--l]);
        while(r < qr) add(w[++r]);
        while(r > qr) del(w[r--]);
        while(now < t){
            now++;  // 先加
            if(c[now].p >= l && c[now].p <= r){ // 更新点在区间内,要更新值的出现次数,-1 +1 即可
                del(w[c[now].p]);
                add(c[now].v);
            }
            swap(w[c[now].p], c[now].v);        // 将对应位置上的数更换掉。
        }
        while(now > t){
            if(c[now].p >= l && c[now].p <= r){
                del(w[c[now].p]);
                add(c[now].v);
            }
            swap(w[c[now].p], c[now].v);
            now--;  // 后减
        }   
        ans[id] = res;
    }
    for(int i = 1; i <= cntq; i++)
        cout << ans[i] << endl;
    return 0;
}

回滚莫队

  1. 增加或删除其中一种操作好做另一种不好做,可以尝试回滚莫队。

做过的可维护信息:

  1. 维护查询区间 \(max{w_i\times cnt_{w_i}}\)
  2. 维护查询区间 \(相同数字距离最大值\)

维护 \((权值\times 出现次数) 的最大值\)

#include<bits/stdc++.h>
typedef long long ll;
#define endl "\n"
using namespace std;
const int N = 1e5 + 10;
int w[N], len, n, m, cnt[N];
ll ans[N];
vector<int> alls;

// 回滚莫队
// 以左端点块号分类,对于相同左端点区间,右端点在块内为一类,右端点在另外块为一类
// 第一类直接暴力
// 第二类由于右端点递增可以很好维护,左端点 l 距离 ql 最多sqrt(n),每次暴力来回更新
// 复杂度(n * sqrt(n)),块长度为 sqrt(n)

int get(int x){
    return x / len;
}

struct Q{
    int l, r, id;
    bool operator < (const Q& q)const{
        int al = get(l), bl = get(q.l);
        if(al != bl) return al < bl;
        return r < q.r;
    }
}q[N];

void add(int x, ll& res){
    cnt[x]++;
    res = max(res, 1ll * alls[x] * cnt[x]);
}

int main(){
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m;
    len = sqrt(n);
    for(int i = 1; i <= n; i++){
        cin >> w[i];
        alls.pb(w[i]);
    }
    sort(alls.begin(), alls.end());
    alls.erase(unique(alls.begin(), alls.end()), alls.end());
    for(int i = 1; i <= n; i++){
        w[i] = lower_bound(alls.begin(), alls.end(), w[i]) - alls.begin();
    }

    for(int i = 1; i <= m; i++){
        int l, r;
        cin >> l >> r;
        q[i] = {l, r, i};
    }
    sort(q + 1, q + 1 + m);
    int x = 1;
    while(x <= m){
        int y = x;
        while(y <= m && get(q[x].l) == get(q[y].l)) y++;
        // 块内
        int right = get(q[x].l) * len + len - 1;
        while(x < y && q[x].r <= right){
            ll res = 0;
            auto [l, r, id] = q[x++];
            for(int i = l; i <= r; i++) add(w[i], res);
            for(int i = l; i <= r; i++) cnt[w[i]]--;
            ans[id] = res;
        }
        ll res = 0;
        int l = right + 1, r = right;
        // 块间
        while(x < y){
            auto [ql, qr, id] = q[x++];
            while(r < qr) add(w[++r], res);
            ll backup = res;        // 信息备份,可能不止 res 一个信息需要备份
            while(l > ql) add(w[--l], res);
            while(l <= right) cnt[w[l++]]--;
            ans[id] = res;
            res = backup;
        }
        memset(cnt, 0, sizeof cnt);
    }
    for(int i = 1; i <= m; i++)
        cout << ans[i] << endl;
    return 0;
}

树上莫队

核心:利用 dfn 将树上问题->序列问题。

#include<bits/stdc++.h>
typedef long long ll;
typedef std::pair<int, int> PII;
typedef std::pair<ll, ll> PLL;
//#pragma GCC optimize(3,"Ofast","inline")
#define x first
#define y second
#define pb push_back
#define mkp make_pair
#define endl "\n"
using namespace std;
const int N = 4e4 + 10, M = 1e5 + 10, Lg = 20;
int len, n, m, top;
int h[N], e[2 * N], ne[2 * N], idx, w[N], cnt[N], st[N], seq[N * 2], in[N], out[N], ans[M];
int dep[N], fa[N][Lg];
vector<int> alls;

// 树上莫队,统计两点间路径不同权值数
// 1. 离散化权值、建图
// 2. dfs 得到 欧拉序列 和 LCA,将 树上问题 转换为 序列问题。
// 3. 根据询问中 a 和 b 是否是一条链进行分类讨论
//   3.1. 是一条链,询问序列为 [in[a], in[b]],序列中出现一次的点是路径上的点。
//	 3.2. 不是一条链,询问序列为 [out[a], in[b]], 并且记录两者的 LCA,加上序列中出现一次的点 就是路径上的点
// 4. 进行莫队算法,分块、排序。
// 5. 按照莫队算法流程来处理每个询问,cnt[v]记录权值v得出现次数,st[x]记录x节点的出现次数。
// 6. st[x] 每次 ^= 1 可以将删除和增加操作合并,然后依次判断即可,最后输出答案

void add_edge(int a, int b){
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

int get(int x){
	return x / len;
}

struct Q{
	int l, r, id, p;
	bool operator < (const Q& q) const{
		int al = get(l), bl = get(q.l);
		if(al != bl) return al < bl;
		return r < q.r;
	}
}q[M];

void dfs(int u, int p){
	in[u] = ++top;
	seq[top] = u;
	for(int i = h[u]; ~i; i = ne[i]){
		int j = e[i];
		if(j == p) continue;
		dep[j] = dep[u] + 1;
		fa[j][0] = u;
		dfs(j, u);
	}
	out[u] = ++top;
	seq[top] = u;
}

void init(){
	for(int i = 1; i < Lg; i++)		// 先循环跳的次数
		for(int j = 1; j <= n; j++)		// 再循环节点个数
			if(fa[j][i - 1])
				fa[j][i] = fa[fa[j][i - 1]][i - 1];
}

int lca(int a, int b){		// 求点 a 和 点 b 的最近公共祖先
	if(dep[a] < dep[b])
		swap(a, b);
	int d = dep[a] - dep[b];		// 深度大的是 a
	for(int i = 0; i < Lg && d; i++, d /= 2){
		if(d & 1)
			a = fa[a][i];
	}
	if(a == b) return a;
	for(int i = Lg - 1; i >= 0; i--)
		if(fa[a][i] != fa[b][i])
			a = fa[a][i], b = fa[b][i];
	return fa[a][0];
}

void add(int x, int& res){
	st[x] ^= 1;
	if(st[x]){
		if(!cnt[w[x]]) res++;
		cnt[w[x]]++;
	}
	else{
		cnt[w[x]]--;
		if(!cnt[w[x]]) res--;
	}
}


int main(){
	ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    memset(h, -1, sizeof h);
    cin >> n >> m;
    for(int i = 1; i <= n; i++)
    	cin >> w[i], alls.pb(w[i]);
    sort(alls.begin(), alls.end());
    alls.erase(unique(alls.begin(), alls.end()), alls.end());
    for(int i = 1; i <= n; i++)
    	w[i] = lower_bound(alls.begin(), alls.end(), w[i]) - alls.begin();
    for(int i = 1; i < n; i++){
    	int u, v;
    	cin >> u >> v;
    	add_edge(u, v), add_edge(v, u);
    }
    dfs(1, -1);
    init();
    len = sqrt(top);
    for(int i = 1; i <= m; i++){
    	int a, b, p;
    	cin >> a >> b;
    	if(in[a] > in[b]) swap(a, b);
    	p = lca(a, b);
    	if(p == a)
    		q[i] = {in[a], in[b], i};
    	else
    		q[i] = {out[a], in[b], i, p};
    }
    sort(q + 1, q + 1 + m);
    for(int i = 1, l = 1, r = 0, res = 0; i <= m; i++){
    	auto [ql, qr, id, p] = q[i];
    	while(l > ql) add(seq[--l], res);
    	while(l < ql) add(seq[l++], res);
    	while(r < qr) add(seq[++r], res);
    	while(r > qr) add(seq[r--], res);
    	if(q[i].p){
    		add(q[i].p, res);
    		ans[id] = res;
    		add(q[i].p, res);
    	}
    	else ans[id] = res;
    }
    for(int i = 1; i <= m; i++)
    	cout << ans[i] << endl;
    return 0;
}

重链剖分

定义

重子节点:子节点中子树最大的子节点,多个取其一,没有子节点就没有。
轻子节点:剩余所有子节点,重子节点以外节点。
重边:节点到重子节点的边
重链:若干首尾相连的重边,把落单节点也酸橙重链,那么树被剖分成若干条重链

性质

  1. 树上每个节点都属于且仅属于一条重链
  2. 重链开头的结点不一定是重子节点(因为重边是对于每一个结点都有定义的)
  3. 所有的重链将整棵树 完全剖分
  4. 在剖分时 重边优先遍历,最后树的 DFN 序上,重链内的 DFN 序是连续的。按 DFN 排序后的序列即为剖分后的链。
  5. 一颗子树内 DFN 序连续
  6. 向下经过一条 轻边,子树大小至少除二。
  7. 因此,对于树上的任意一条路径,把它拆分成从 \(LCA\) 分别向两边往下走,分别最多走 \(O(logn)\) 次,因此,树上的每条路径都可以被拆分成不超过 \(O(logn)\) 条重链。
  • 重链剖分可以将树上的任意一条路径划分成不超过 \(O(logn)\) 条连续的链,
    每条链上的点深度互不相同(即是自底向上的一条链,链上所有点的 LCA 为链的一个端点)。

  • 重链剖分还能保证划分出的每条链上的节点 DFS 序连续,
    因此可以方便地用一些维护序列的数据结构(如线段树)来维护树上路径的信息。

    • 修改 树上两点之间的路径上所有点的值。
    • 查询 树上两点之间的路径上 节点权值的 和/极值/其它(在序列上可以用数据结构维护,便于合并的信息)
int fa[N], dep[N], sz[N], hson[N], top[N], dfn[N], rk[N], tot;
void dfs1(int u){        // 获取 fa[u],dep[u],sz[u],hson[u]
    hson[u] = -1;       // dep[root] 务必为0
    sz[u] = 1;
    for (auto v: edge[u]) {
        if(v == fa[u]) continue;
        dep[v] = dep[u] + 1;
        fa[v] = u;
        dfs1(v);
        sz[u] += sz[v];
        if(hson[u] == -1 || sz[v] > sz[hson[u]]) hson[u] = v;
    }
}

void dfs2(int u, int tp){
    top[u] = tp;
    dfn[u] = ++ tot;
    rk[tot] = u;
    if(hson[u] == -1) return;   // 没有重节点 等价 是叶子节点
    dfs2(hson[u], tp);   // 优先对重儿子dfs,保证同一重链点 DFS 序连续,继承链顶
    for (auto v: edge[u]) {
        if (v == hson[u] || v == fa[u]) continue;
        dfs2(v, v);  // 不是重节点,top[v] = v;
    }
}

void path_modify(int u, int v, int val) {
   while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        tr.modify(1, dfn[top[u]], dfn[u], val);
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    tr.modify(1, dfn[u], dfn[v], val);      // dfn Not dep
    return;
}

int path_query(int u, int v) {
    int res = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        res += tr.query(1, dfn[top[u]], dfn[u]);
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    res += tr.query(1, dfn[u], dfn[v]);
    return res;
} 

应用

路径上维护

例:维护两点路径权值和

  • 链上的 DFS 序是连续的,可以使用线段树、树状数组维护。
  • 每次选择深度较大的链往上跳,直到两点在同一条链上。
  • 同样的跳链结构适用于维护、统计路径上的其他信息。
int path_sum(int u, int v){
    int res = 0;
    while(top[u] != top[v]){
        if(dep[top[u]] < dep[top[v]])
            swap(u, v);
        res += ... // res += u 到 top[u] 的路径和
        u = fa[top[u]];
    }
    res += ... // res += u 到 v 上的路径和
}

子树维护

由于子树 DFS 序是连续的,很容易对整个子树进行操作。记录 in[u], out[u] 即可。

求最近公共祖先(LCA)

  • 不断向上跳重链,当跳到同一条重链上时,深度较小的结点即为 LCA。
  • 向上跳重链时需要先跳所在重链顶端深度较大的那个
int lca(int u, int v){
    while(top[u] != top[v]){
        if(dep[top[u]] > dep[top[v]])
            u = fa[top[u]];
        else
            v = fa[top[v]];
    }
    return dep[u] > dep[v] ? v : u;
}

Dsu On Tree(树上启发式合并)

// init中,我们将dfs序求出来,节点i的重儿子用hs[i]表示,没有的话就是-1
// 当然这里可以初始化些其他的东西
void dfs_init(int u, int fa){
    l[u] = ++tot;
    id[tot] = u;
    sz[u] = 1;
    hs[u] = -1;
    for(int i = h[u]; i != - 1; i = ne[i]){
        int j = e[i];
        if(j == fa) continue;
        dfs_init(j, u);
        sz[u] += sz[j];
        if(hs[u] == -1 || sz[j] > sz[hs[u]]) hs[u] = j;
    }
    r[u] = tot;
}


int cnt[N], maxcnt;
long long maxnsum, sumcnt;
// keep表示这个节点为根的子树的信息是不是要保留,重链保留,轻链不保留
void dfs_solve(int u, int fa, bool keep){
    // 轻链的话,我们直接递归求解就可以了,最后的信息不用保留
    for(int i = h[u]; i != -1; i = ne[i]){
        int j = e[i];
        if(j ==  fa || j == hs[u]) continue;
        dfs_solve(j, u, false);
    }
    // 重链的话,我们同样递归求解,求解得到的信息需要保留
    if(hs[u] != -1){
        dfs_solve(hs[u], u, true);
    }

    auto add = [&](int x){
    	for(int i = 0; i <= 18; i ++){
    		f[a[x]][i][x >> i & 1] ++;
    	}
    };

    auto query = [&](int x){
    	for(int i = 0; i <= 18; i ++){
    		ans += f[a[u] ^ a[x]][i][(~x) >> i & 1] * ( 1ll << i);
    	}        
    };    
    // 遍历轻链的每个节点,将其加入,关键是写好加入维护的信息有什么影响,对其更新就好了
    // 这里需要注意的是add(id[x]), 并不是add(x)
    for(int i = h[u]; i != -1; i = ne[i]){
        int j = e[i];
        if(j == fa || j == hs[u]) continue;
        for(int x = l[j]; x <= r[j]; x ++){
        	query(id[x]);
       }
        for(int x = l[j]; x <= r[j]; x ++){
            add(id[x]);
        }
    }
    // 将根节点加入,并且这里也要记得对维护的信息进行更新
    add(u);
    // query(u);
    // ans[u] = sumcnt;
    auto del = [&](int x){
    	for(int i = 0; i <= 18; i ++){
    		f[a[x]][i][x >> i & 1] --;
    	}
    };
    // 如果信息不需要保留的话,把信息清空就可以了,这里其实是比较简单的
    if(!keep){
        maxcnt = 0, sumcnt = 0;
        for(int x = l[u]; x <= r[u]; x ++) del(id[x]);
    }

}

虚树

struct VirutalTree {
    /* 需要预处理 dfs 序、LCA */
    vector<int> edge[N], nodes;
    int stk[N], top;
    void add_edge(int a, int b) { edge[a].pb(b); }
    void build() {      // 读入 k 个数,作为关键点插入
        int k;
        scanf("%d", &k);
        nodes.clear(), nodes.resize(k);
        for (auto &t : nodes)
            re(t), query[t] = true;
        sort(ALL(nodes), [&](int a, int b){
            return dfn[a] < dfn[b];
        });
        stk[top = 1] = 1, edge[1].clear();
        for (const auto u: nodes) 
            insert(u);
        for (int i = 1; i < top; i++)
            add_edge(stk[i], stk[i + 1]);
    }
    /* 入栈清边 */
    void insert(int u) {
        if (h[u] == 1) return;      // 先插入了 1,避免重复
        int p = lca(u, stk[top]);
        if (p != stk[top]) {
            whie (dfn[p] < dfn[stk[top - 1]])
                add_edge(stk[top - 1], stk[top]), top --;
            if (dfn[p] == dfn[stk[top - 1]])
                add_edge(p, stk[top--]);
            else 
                edge[p].clear(), add_edge(p, stk[top]), stk[top] = p;
        }
        edge[u].clear(), stk[++top] = u;
    }
    // extension
} vt;

字符串

定义及性质

Border

字符串 \(S\) 的同长度前缀和后缀完全相同, \(Prefix[i] = Suffix[i] <=> S[1,p] == S[|S|-p+1, |S|]\) ,则称为 \(Border\) ,字符串本身可以是自己的 \(Border\) 根据情况判断。

  • \(Prefix[i]\) 的 Border 长度减 \(1\)\(Prefix[i - 1]\) 的 Border 长度,反之不一定成立, 需要检验后一个字符是否相等。

周期

  • 对于字符串 \(S\) 和正整数 \(p\) ,如果有 \(S[i] = S[i - p]\) ,对于 \(p < i \leq |S|\) 成立,则 \(p\) 为字符串的一个周期。
  • 当然,\(p=|S|\) 一定是 \(S\) 的周期

循环节

  • \(p\) 是字符串 \(S\) 的周期,满足 \(p \;|\; |S|\) ,则 \(p\)\(S\) 的一个循环节。
  • 当然,\(p = |S|\)\(S\) 的循环节

性质

  • \(p\)\(S\) 的周期等价于 \(|S| - p\)\(S\) 的 Border。
    • 即字符串周期性质等价于 Border 性质,注意 Border不具有二分性
  • Border 具有传递性,即 Border 的 Border 也是字符串的 Border。
    • 即求字符串的所有 Border 等价于求所有前缀的最大Border。

Border树

对于字符串 \(S\)\(n = |S|\),它的 Border 树 (next 树) 共有 \(n+1\) 个节点:\(0, 1, 2, 3,..,n\)\(0\) 是这颗有向树的根。对于其他节点父节点为 \(ne[i]\)

性质

  • 每个前缀 \(Prefix[i]\) 的所有 Border ,就是节点 \(i\) 到根的链。
  • 哪些前缀有长度为 \(x\) 的 Border,等价于 \(x\) 的子树
  • 求两个前缀的公共 Border,等价于求两个节点的 \(LCA\)

KMP

fail数组

  • \(fail[i] = Prefix[i]\) 的非平凡最大 Border,在前缀里找 Border。
  • \(fail[1] = 0\)
  • \(Prefix[i]\) 的所有长度大于 1 的 Border。去掉最后一个字母就变成 \(Prefix[i - 1]\) 的Border
    • 故求 \(fail[i]\) 的时候,遍历 \(Prefix[i - 1]\) 的所有 Border,即 \(fail[i - 1], fail[fail[i - 1]], ... , 0\)。检查最后一个字符是否等于 \(S[i]\)
//#define _Border
struct KMP {
    const static int maxn = 1e6 + 10;
    #ifdef _Border
    int h[maxn], e[maxn << 1], ne[maxn << 1], idx;
    void add(int a, int b) {
        e[idx] = b, ne[idx] = h[a], h[a] = idx++;
    }
    KMP() {init(); }
    void init() {   // 初始化
        memset(h, -1, sizeof h), idx = 0;
    }
    #endif
    int fail[maxn], len;
    /* 获取 s 串的 fail[] 数组*/
    void get_ne(const char* s, int _len) {
        fail[1] = 0, len = _len;
        for (int i = 2, j = 0; i <= len; i ++) {
            while (j && s[j + 1] != s[i]) j = fail[j];
            if (s[i] == s[j + 1]) j++;
            fail[i] = j;
        }
        #ifdef _Border
        for (int i = 1; i <= len; i++) {add(fail[i], i); }
        #endif
    }
    /* 查找 p 在 s 中的匹配位置 */
    void match(const char* s, const char* p, int lens, int lenp) {
        for (int i = 1, j = 0; i <= lens; i++) {
            while (j && s[i] != p[j + 1]) j = fail[j];
            if (s[i] == p[j + 1]) j++;
            if (j == lenp) {
                // Extension
                j = fail[j];
            }
        }
    }
    void debug() {
        for (int i = 1; i <= len; i++)
            printf("fail[%d] = %d\n", i, fail[i]);
    }
} kmp;

字符串哈希

对字符串前缀进行哈希映射,从左到右,高位到低位,映射成 \(P\) 进制数。

  • 数组写成 unsigned long long 可以达到自动取模的目的
  • 得到 \([l,r]\) 字串哈希值,\(value=h_r-h_{l-1}\times p^{r-l+1}\)

单模

#define DT ull
const ll base = 911, mod = 4294967291ull;
DT B[N], h[N];

DT get(char s) {
    return s - 'a' + 1;
}

void init(const char* s, int len) {
    B[0] = 1;
    for (int i = 1; i <= len; i++) {
        B[i] = B[i - 1] * base % mod;
        h[i] = h[i - 1] * base % mod + get(s[i]);
        if (h[i] >= mod)
            h[i] -= mod;
    }
}

DT Hash(int l, int r) {
    ll res = h[r] - h[l - 1] * B[r - l + 1] % mod;
    if (res < 0) res += mod;
    return res;
}

多模

#define ull unsigned long long
const int HASH_CNT = 2; // hash 次数 
const ull Prime_Pool[] = {1998585857ul,23333333333ul};    
const ull Base_Pool[]={131, 911, 146527, 19260817,91815541};
const ull Mod_Pool[]={4294967291ull, 1000000181, 1000000403, 29123};
struct StringHash{
    const static int maxn = 2e5 + 10;
    #define DT ull
    int n;
    const DT mod, Base, flag; 
    DT base[maxn], sum[maxn];
    // sum[i] = s[i]+s[i-1]*Seed+s[i-2]*Seed^2+...+s[1]*Seed^(i-1)
    // hasher[HASH_CNT] = {Hash(N, Base_Pool[0], Mod_Pool[0]), Hash(N, Base_Pool[1], Mod_Pool[1])};;
    StringHash(int _n, const int _Base, const int _mod, const int _flag = 0): n(_n), Base(_Base), mod(_mod), flag(_flag) {
        base[0] = 1;
        for (int i = 1; i <= n; ++i){
            base[i] = base[i - 1] * Base % mod;
        }
    }
    DT get(char c) {
        return c - 'a' + 1;
    }
    // flag == 0,对字符串正着hash, flag==1
    void indexInit(const char* s, int len){       
        if(flag == 0)
            for (int i = 1;i <= len; ++i){
                sum[i] = (sum[i - 1] * Base % mod + get(s[i]));
                if (sum[i] >= mod) sum[i] -= mod;
            }
        else{
            for (int i = 1;i <= len; ++i){
                sum[i] = (sum[i - 1] * Base % mod + get(s[len - i + 1]));
                if (sum[i] >= mod) sum[i] -= mod;
            }
        }
    }
    DT getHash(int l, int r){
        ull res = sum[r] - sum[l - 1] * base[r - l + 1] % mod;
        if (res < 0) res += mod;
        return res;
    }
};  

AC自动机

原理

  1. 很多都类比 KMP 算法,但有自己的优化(Trie图优化)

应用及扩展

  • 运用ne[]数组特性,同时cnt[]记录所有单词前缀的出现次数,从后向前累加所有cnt,记录单词节点位置,可以统计字符串中所有单词的出现次数

Trie图优化模板

#define FAIL
const int N = 1e6 + 10, sigma = 26;
int endpos[N];
#ifdef FAIL 
int h[N], e[N << 1], ne[N << 1], _idx;
int dfn[N], rk[N], tot, sz[N];
void add(int a, int b) {
    e[_idx] = b, ne[_idx] = h[a], h[a] = _idx++;
}
#endif
struct ACAM {
    int cnt[N], tr[N][sigma], idx, fail[N];
    void clear() {
        for (int i = 0; i <= idx; i++)
            cnt[i] = 0, memset(tr[i], 0, sizeof tr[i]), memset(fail, 0, sizeof fail);
        idx = 0;
    }
    void insert(const char* s){
        int p = 0;
        for(int i = 0; s[i]; i++){
            int c = s[i] - 'a';
            if(!tr[p][c]) tr[p][c] = ++idx;
            p = tr[p][c];
            cnt[p]++;
        }
    }
    void build(){
        queue<int> q;
        for(int i = 0; i < sigma; i++)
           if(tr[0][i])
                q.push(tr[0][i]);
        while(q.size()){
            auto u = q.front();
            q.pop();
            #ifdef FAIL 
            add(fail[u], u);
            #endif
            for(int i = 0; i < sigma; i++){
                if(tr[u][i])
                    fail[tr[u][i]] = tr[fail[u]][i], q.push(tr[u][i]);
                else
                    tr[u][i] = tr[fail[u]][i];
            }
        }
    }
    int query(const char * s){
        int u = 0, res = 0;
        for(int i = 0; s[i]; i++){
            u = tr[u][s[i] - 'a'];
            for(int j = u; j && cnt[j] != -1; j = fail[j])
                res += cnt[j], cnt[j] = -1;
        }
        return res;
    }
    int ans[N];
    #ifdef FAIL 
    void dfs(int u) {
        dfn[u] = ++tot;
        rk[u] = tot, sz[u] = 1;
        for (int i = h[u]; ~i; i = ne[i]) {
            int v = e[i];
            dfs(v);
            sz[u] += sz[v];
        }
    }
    #endif
}acam; 

Manacher

const int N = 2e5 + 100;
char s[N];
struct Manacher {
    int lc[N << 1];
    char ch[N << 1];
    int len, n;
    void build(const char* s, int _n) {
        n = _n;
        init(s);
        manacher();
    }
    /* s 1 bas , Manacher manacher(s)*/
    void init(const char *s) {
        ch[n * 2 + 1] = '#';
        ch[0] = '@';
        ch[n * 2 + 2] = '\0';
        for (int i = n; i >= 1; i--) {
            ch[i * 2] = s[i], ch[i * 2 - 1] = '#';
        }
        len = 2 * n + 1;
    }
    void manacher() {
        lc[1] = 1;
        int k = 1;      // k 是 最右子串 回文中心
        for (int i = 2; i <= len; i++) {
            int p = k + lc[k] - 1;  // 最右子串 R
            if (i <= p) {   // 在最右子串内, 继承对称点的回文半径
                lc[i] = min(lc[2 * k - i], p - i + 1);
            }
            else {
                lc[i] = 1;
            }
            while (ch[i + lc[i]] == ch[i - lc[i]])      // 暴力拓展
                lc[i]++;
            if (i + lc[i] > k + lc[k])  
                k = i;
        }
    }
    void debug() {
        for (int i = 1; i <= len; i++) 
            i == len ? cout << ch[i] << "\n" : cout << ch[i];
        for (int i = 1; i <= len; i++) {
            cout << "lc[" << i << "]" << "=" << lc[i] << "\n";
        }
    }
}ches;

Z Algorithm

// z[i]数组表示以i开始的后缀和前缀的最大匹配
int z[N];
void getz(const char *s, int len) {
    for (int i = 1; i <= len; i++) z[i] = 0;
    z[1] = len;
    for (int i = 2, l = 0, r = 0; i <= len; i++) {
        if (i <= r) z[i] = min(z[i - l + 1], r - i + 1);
        while (i + z[i] <= len && s[i + z[i]] == s[z[i] + 1]) ++ z[i];
        if (i + z[i] - 1 > r) l = i, r = i + z[i] - 1;
    }
}

最小表示法

int min_show(const char *s, int len) {   // 下标从0开始
    int k = 0, i = 0, j = 1;
    while (k < len && i < len && j < len)
    {
        if (s[(i + k) % len] == s[(j + k) % len]) k++;
        else {
            s[(i + k) % len] > s[(j + k) % len] ? i = i + k + 1 : j = j + k + 1;
            if (i == j) i++;
            k = 0;
        }
    }
    return min(i, j);
}

回文自动机 PAM

每个节点代表一个回文串, \(l[]\) 表示对应的回文串长度,每次向下转移长度 + 2

const int maxn = 3e5 + 100;
struct PAM {
    // basic
    int s[maxn], now;
    int nxt[maxn][26], fail[maxn], l[maxn], last, tot;
    // extension
    int num[maxn]; /*节点代表的所有回文串出现次数*/
    void clear() {
        // 1节点:奇数长度root 0节点:偶数长度root, fail[0] = 1, fail[1] = 0, now 初始指向 1(奇根)
        s[0] = l[1] = -1;
        fail[0] = tot = now = 1;
        last = l[0] = 0;
        memset(nxt[0], 0, sizeof nxt[0]);
        memset(nxt[1], 0, sizeof nxt[1]);
    }
    PAM() { clear(); }
    int newnode(int ll) {       // 建立长度为 ll 的节点。
        tot++;
        memset(nxt[tot], 0, sizeof nxt[tot]);
        fail[tot] = num[tot] = 0;
        l[tot] = ll;
        return tot;
    }
    int get_fail(int x) {
        while (s[now - l[x] - 2] != s[now - 1])
            x = fail[x];
        return x;
    }
    void add(int ch) {
        s[now++] = ch;
        int cur = get_fail(last);
        if (!nxt[cur][ch]) {
            int u = newnode(l[cur] + 2);
            fail[u] = nxt[get_fail(fail[cur])][ch];
            nxt[cur][ch] = u;
        }
        last = nxt[cur][ch];
        num[last]++;        // 成为 last 出现次数 + 1
    }
    void build() {
        // fail[i]<i,拓扑更新可以单调扫描。
        for (int i = tot; i >= 2; i--) {
            num[fail[i]] += num[i];
        }
        num[0] = num[1] = 0;
    }
    void init(const char* str, int len) {
        for (int i = 1; i <= len; i++)
            add(s[i] - 'a');
    }
    long long query();
} pam;
long long PAM::query(){
    long long res =1;
    for (int i = 2; i <= tot; i++){
        res = max(res, 1LL * l[i] * num[i]);
    }
    return res;
}

后缀数组 SA

前缀倍增

\(O(logn)\)

#define RMQ
const int maxn = 1e6 + 10;
struct SA {
    #ifndef RMQ
    struct Segment_Tree {
        #define ls u << 1
        #define rs u << 1 | 1
        int min_val[maxn << 2];
        void pushup(int u) {
            min_val[u] = min(min_val[ls], min_val[rs]);
        }
        void build(int u, int l, int r, int* h) {
            if (l == r) {
                min_val[u] = h[l];
                return ;
            }
            int mid = (l + r) >> 1;
            build(ls, l, mid, h), build(rs, mid + 1, r, h);
            pushup(u);
        }
        int query(int u, int l, int r, int ql, int qr) {
            if (l > qr || ql > r) return 0x3f3f3f3f;
            if (ql <= l && r <= qr) return min_val[u];
            int mid = (l + r) >> 1;
            return min(query(ls, l, mid, ql, qr), query(rs, mid + 1, r, ql, qr));
        }
    }segtree;
    #else
    int st[maxn][20], lg[maxn];
    void init_st() {
        for (int i = 2; i < maxn; i++) lg[i] = lg[i / 2] + 1;
        for (int i = 1; i <= n; ++i) st[i][0] = height[i];
        for (int j = 1; (1 << j) <= n; ++j) {
            for (int i = 1; i <= (n - (1 << j) + 1); ++i) {
                st[i][j] = min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
            }
        }
    }
    #endif
    /*height[i] = lcp(S[sa[i]],S[sa[i-1]]), h[i]=height[rk[i]], h[i]>=h[i-1]-1, lcp(s[i],s[j])=min(height[rk[i]+1],...,height[rk[j]])*/
    int n, sa[maxn], rk[maxn], id[maxn], cnt[maxn], height[maxn], px[maxn];   
    void get_sa(const char* s, int _n) {    // get sa and height
        n = _n;
        int m = 300, p = 0;      // m 是值域, 初始化为字符集大小
        for (int i = 0; i <= m; i++) cnt[i] = 0;
        for (int i = 1; i <= n; ++i) cnt[rk[i] = (int)s[i]] ++; // 先对1个字符大小的子串进行计数排序
        for (int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i;
        for (int w = 1; w <= n; w <<= 1, m = p, p = 0) { // m=p 就是优化计数排序值域
            for (int i = n - w + 1; i <= n; ++i) // 第二关键字无穷小先放进去
                id[++p] = i;
            for (int i = 1; i <= n; ++i) 
                if (sa[i] > w) id[++p] = sa[i] - w; // 顺次放入 s[sa[i]-w] 的第二关键字排名
            for (int i = 0; i <= m; ++i) cnt[i] = 0;
            for (int i = 1; i <= n; ++i) ++cnt[rk[i]], px[i] = rk[id[i]];  
            for (int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
            for (int i = n; i >= 1; --i) sa[cnt[px[i]]--] = id[i];
            for (int i = 1; i <= n; ++i) swap(rk[i], id[i]);
            rk[sa[1]] = p = 1;
            for (int i = 2; i <= n; ++i) {
                rk[sa[i]] = (id[sa[i]] == id[sa[i - 1]] && id[sa[i] + w] == id[sa[i - 1] + w] ? p : ++p);
            }
            if (p >= n) {       // 排名已经更新出来了
                break;
            }
        }
    }
    void get_height(const char* s){
        for (int i = 1, k = 0; i <= n; ++i) {       // 获取 height数组
            if (k) --k;
            int j = sa[rk[i] - 1];
            while (s[i + k] == s[j + k]) ++k;
            height[rk[i]] = k;
        }
#ifdef _DEBUG
        for (int i = 1; i <= n; ++i) 
            cout<<"height["<<i<<"] = "<<height[i]<<endl;
        }
#endif
    }

    void init() {
#ifndef RMQ
        segtree.build(1, 1, n, height);
#else
        init_st();
#endif
    }
    int get_lcp(int x, int y) {
        int rkx = rk[x], rky = rk[y];
        if (rkx > rky) swap(rkx, rky);
        rkx++;
#ifndef RMQ
        int lcp = segtree.query(1, 1, n, rkx, rky);
#else
        int k = lg[(rky - rkx + 1)];
        int lcp = min(st[rkx][k], st[rky - (1 << k) + 1][k]);
#endif

#ifdef _DEBUG
        cout<<"[getlcp] x="<<x<<" y="<<y<<" rkx="<<rkx<<" rky="<<rky<<" lcp="<<lcp<<endl;
#endif
        return lcp;
    }
}sa;

后缀自动机

const int maxn = 1e6 + 10;
struct SAM {
    //basic
    const char* s;
    int last, cnt, len;
    int nxt[maxn * 2][26],fa[maxn * 2],l[maxn * 2];
    //extension
    int cntA[maxn * 2], id[maxn * 2];/*辅助拓扑更新*/
    int num[maxn * 2];/*每个节点代表的所有串的出现次数*/
    SAM () { clear(); }
    void clear() {
        last = cnt = 1, l[1] = fa[1] = 0, memset(nxt[1], 0, sizeof nxt[1]);
    }
    void init(const char * str, int _len) {
        s = str, len = _len;
        for (int i = 1; i <= _len; i++)
            extend(str[i] - 'a');
    }
    void extend(int c) {
        int p = last, np = ++cnt;
        memset(nxt[cnt], 0, sizeof nxt[cnt]);
        l[np] = l[p] + 1, last = np;
        while (p && !nxt[p][c]) nxt[p][c] = np, p = fa[p];
        if (!p) fa[np] = 1;
        else {
            int q = nxt[p][c];
            if (l[q] == l[p] + 1) fa[np] = q;
            else {
                int nq = ++cnt;
                l[nq] = l[p] + 1;
                memcpy(nxt[nq], nxt[q], sizeof(nxt[q]));
                fa[nq] = fa[q], fa[np] = fa[q] = nq;
                while (nxt[p][c] == q) nxt[p][c] = nq, p = fa[p];
            }
        }
    }
    void build() {
        memset(cntA, 0, sizeof cntA);
        memset(num, 0, sizeof num);
        for (int i = 1; i <= cnt; i++) cntA[l[i]]++;
        for (int i = 1; i <= cnt; i++) cntA[i] += cntA[i - 1];
        for (int i = cnt; i >= 1; i--) id[cntA[l[i]]--] = i;
        /*更新主串节点*/
        int temp = 1;
        for (int i = 1; i <= len; i++) {
            num[temp = nxt[temp][s[i] - 'a']] = 1;
        }
        /*拓扑更新*/
        for (int i = cnt; i >= 1; i--) {
            // basic
            int x = id[i];
            num[fa[x]] += num[x];
            // extension
        }
        // extension
    }
    void debug(){
        for (int i = cnt; i >= 1; i--){
            printf("num[%d]=%d l[%d]=%d fa[%d]=%d\n",i,num[i],i,l[i],i,fa[i]);
        }
    }
}sam;

广义后缀自动机 (在线)

const int maxn = 2e6 + 10;
struct EXSAM {
    //basic
    const char* s;
    int cnt, len;
    int nxt[maxn * 2][26],fa[maxn * 2],l[maxn * 2];
    //extension
    queue<int> q;
    int cntA[maxn * 2], id[maxn * 2];/*辅助拓扑更新*/
    int num[maxn * 2];/*每个节点代表的所有串的出现次数*/
    EXSAM () { clear(); }
    void clear() {
        cnt = 1, l[1] = fa[1] = 0, memset(nxt[1], 0, sizeof nxt[1]), memset(num, 0, sizeof num);
    }
    int extend(int c, int last) {
        if (nxt[last][c]) {
            int p = last, x = nxt[p][c];
            if (l[p] + 1 == l[x]) { num[x] = 1; return x;}
            int y = ++cnt;
            l[y] = l[p] + 1;
            memcpy(nxt[y], nxt[x], sizeof nxt[x]);
            while (p && nxt[p][c] == x) nxt[p][c] = y, p = fa[p];
            fa[y] = fa[x], fa[x] = y;
            num[y] = 1;
            return y;
        }
        int p = last, np = ++cnt;
        memset(nxt[cnt], 0, sizeof nxt[cnt]);
        l[np] = l[p] + 1, last = np;
        while (p && !nxt[p][c]) nxt[p][c] = np, p = fa[p];
        if (!p) fa[np] = 1;
        else {
            int q = nxt[p][c];
            if (l[q] == l[p] + 1) fa[np] = q;
            else {
                int nq = ++cnt;
                l[nq] = l[p] + 1;
                memcpy(nxt[nq], nxt[q], sizeof(nxt[q]));
                fa[nq] = fa[q], fa[np] = fa[q] = nq;
                while (nxt[p][c] == q) nxt[p][c] = nq, p = fa[p];
            }
        }
        num[np] = 1;
        return np;
    }
    void build() {
        memset(cntA, 0, sizeof cntA);
        for (int i = 1; i <= cnt; i++) cntA[l[i]]++;
        for (int i = 1; i <= cnt; i++) cntA[i] += cntA[i - 1];
        for (int i = cnt; i >= 1; i--) id[cntA[l[i]]--] = i;
        /*拓扑更新*/
        for (int i = cnt; i >= 1; i--) {
            // basic
            int x = id[i];
            num[fa[x]] += num[x];
            // extension
        }
        // extension
    }

}sam;

广义后缀自动机 (离线)

const int maxn = 1e6 + 10;

struct Trie {
    int idx, fa[maxn * 26], son[maxn][26], c[maxn * 26]; 
    Trie() {idx = 1;}
    void insert(const char* s) {
        int p = 1;
        for (int i = 1; s[i]; i++) {
            int u = s[i] - 'a';
            if (!son[p][u]) son[p][u] = ++idx, fa[idx] = p, c[idx] = u;
            p = son[p][u];
        }
    }
}Tr;

struct SAM {
    //basic
    const char* s;
    int cnt, len;
    int nxt[maxn * 2][26],fa[maxn * 2],l[maxn * 2];      
    queue<int> q;
    //extension
    int cntA[maxn * 2], id[maxn * 2];/*辅助拓扑更新*/
    int num[maxn * 2];/*每个节点代表的所有串的出现次数*/
    int pos[maxn * 2];  // Trie 上节点在 SAM 上对应的节点编号
    SAM () { clear(); }
    void clear() {
        cnt = 1, l[1] = fa[1] = 0, memset(nxt[1], 0, sizeof nxt[1]);
    }
    void init() {
        for (int i = 0; i < 26; i++) if (Tr.son[1][i]) q.push(Tr.son[1][i]);
        pos[1] = 1;
        while (!q.empty()) {
            int t = q.front(); q.pop();
            pos[t] = extend(Tr.c[t], pos[Tr.fa[t]]);
            for (int i = 0; i < 26; i++) if (Tr.son[t][i]) q.push(Tr.son[t][i]);
        }
    }
    int extend(int c, int last) {
        int p = last, np = ++cnt;
        memset(nxt[cnt], 0, sizeof nxt[cnt]);
        l[np] = l[p] + 1, last = np;
        while (p && !nxt[p][c]) nxt[p][c] = np, p = fa[p];
        if (!p) fa[np] = 1;
        else {
            int q = nxt[p][c];
            if (l[q] == l[p] + 1) fa[np] = q;
            else {
                int nq = ++cnt;
                l[nq] = l[p] + 1;
                memcpy(nxt[nq], nxt[q], sizeof(nxt[q]));
                fa[nq] = fa[q], fa[np] = fa[q] = nq;
                while (nxt[p][c] == q) nxt[p][c] = nq, p = fa[p];
            }
        }
        return np;
    }
    void build() {
        memset(cntA, 0, sizeof cntA);
        memset(num, 0, sizeof num);
        for (int i = 1; i <= cnt; i++) cntA[l[i]]++;
        for (int i = 1; i <= cnt; i++) cntA[i] += cntA[i - 1];
        for (int i = cnt; i >= 1; i--) id[cntA[l[i]]--] = i;
        /*更新主串节点*/
        int temp = 1;
        for (int i = 1; i <= len; i++) {
            num[temp = nxt[temp][s[i] - 'a']] = 1;
        }
        /*拓扑更新*/
        for (int i = cnt; i >= 1; i--) {
            // basic
            int x = id[i];
            num[fa[x]] += num[x];
            // extension
        }
        // extension
    }
    void debug(){
        for (int i = cnt; i >= 1; i--){
            printf("num[%d]=%d l[%d]=%d fa[%d]=%d\n",i,num[i],i,l[i],i,fa[i]);
        }
    }
}sam;

搜索与图论

DFS 深搜

枚举

指数型枚举

选与不选

// 指数型枚举
void dfs(int u){
    if(u > n){
        for(auto t: v)
            printf("%d ", t);
        printf("\n");
        return ;
    }
    v.pb(u);        // 选
    dfs(u + 1);
    v.pob();
    dfs(u + 1);     // 不选
}

\(n\) 皇后枚举示例

两者区别主要在dfs参数所维护的信息导致搜索的顺序不同

按每一个格子选与不选进行枚举

void dfs(int x, int y, int s){
    if(y == n){     // 超出横轴,换到下一层继续枚举
        y = 0;
        x++;
    }
    if(x == n){
        if(s == n){
            for(int i = 0; i < n; i++)
                puts(g[i]);
            puts("");
        }
        return;     // 要return防止爆栈
    }
    // 不选当前格子
    dfs(x, y + 1, s);
    
    // 选当前格子
    if(!dg[x + y] && !udg[x - y + n] && !row[x] && !col[y]){
        g[x][y] = 'Q';
        dg[x + y] = udg[x - y + n] = row[x] = col[y] = true;
        dfs(x, y + 1, s + 1);
        dg[x + y] = udg[x - y + n] = row[x] = col[y] = false;
        g[x][y] = '.';
    }
}

按列进行枚举

// 按行枚举
void dfs(int u){
    if(u == n){
        for(int i = 0; i < n; i++)
            puts(g[i]);
        puts("");
    }
    for(int i = 0; i < n; i++){        // 枚举列
        if(!col[i] && !udg[u - i + n] && !dg[u + i]){
            g[u][i] = 'Q';
            col[i] = udg[u - i + n] = dg[u + i] = true;
            dfs(u+1);
            col[i] = udg[u - i + n] = dg[u + i] = false;
            g[u][i] = '.';
        }
    }
}

组合型枚举

// 组合型枚举
void dfs(int u){
    if(ans.size() > m || ans.size() + (n - u + 1) < m)  // 指数型枚举剪枝
        return;
    if(ans.size() == m){
        for(auto t: ans)
            printf("%d ", t);
        printf("\n");
        return;
    }
    ans.pb(u);
    dfs(u + 1);
    ans.pob();
    dfs(u + 1);
}

排列型枚举

// 排列型枚举
void dfs(int u){        // u 记录当前枚举到哪个位置了
    if(u > n){      // 枚举完了所有就输出然后return
        for(auto t : ans)
            printf("%d ", t);
        printf("\n");
        return;
    }
    for(int i = 1; i <= n; i++){
        if(!st[i]){
            st[i] = true;
            ans.pb(i);
            dfs(u + 1);
            ans.pob();      // 回溯
            st[i] = false;
        }
    }
}

BFS 宽搜

bfs 一圈一圈地慢慢进行搜索。
解题考虑以下两点:

  1. 队列中节点维护的信息是什么
  2. 如何解决bfs中距离的表示

通用模板
BFS一定注意进入队列中的单个节点维护的信息,这很重要。

int bfs(){
    queue<PII> q;
    q.push({1,1});
    d[1][1] = 0;        // 标记距离左上角的距离,同时初始化为-1,不为-1则代表已访问
    while(q.size()){
        auto t = q.front();
        q.pop();
        for(int i = 0; i < 4; i++){
            int x= t.first + dx[i], y = t.second + dy[i];
            if(x < 1 || x > n || y < 1 || y > m || d[x][y] != -1 || g[x][y] == 1)
                continue;
            d[x][y] = d[t.first][t.second] + 1;
            q.push({x,y});
        }
    }
    return d[n][m];
}

01 BFS

// 01边权双端队列,边权为0放队头,边权为1放队尾,写法类似dijkstra

int bfs(){
    deque<PII> q;
    int dx[4] = {-1, -1, 1, 1}, dy[4] = {-1, 1, 1, -1};     // 偏移量
    int ix[4] = {-1, -1, 0, 0}, iy[4] = {-1, 0, 0, -1};     // 点与边坐标转换
    string cs = "\\/\\/";
    memset(st, false, sizeof st);
    memset(dist, 0x3f, sizeof dist);
    q.push_front({0, 0});
    dist[0][0] = 0;
    while(q.size()){
        auto t = q.front();
        q.pop_front();
        if(t.x == n && t.y == m)
            return dist[n][m];
        if(st[t.x][t.y])
            continue;
        st[t.x][t.y] = true;
        for(int i = 0; i < 4; i++){
            int x = t.x + dx[i], y = t.y + dy[i];
            if(x < 0 || x > n || y < 0 || y > m || st[x][y])        // 点数比行列数多一排一列
                continue;
            int w = (g[t.x + ix[i]][t.y + iy[i]] != cs[i]);
            int d = dist[t.x][t.y] + w;
            if(d < dist[x][y]){
                dist[x][y] = d;
                if(!w)
                    q.push_front({x, y});
                else
                    q.push_back({x, y});
            }
        }
    }
}

树与图的存储

  • 树是一种特殊的图,与图的存储方式相同
  • 对于无向图存储 \(ab\) ,存储两条有向边, \(a->b\), \(b->a\)

邻接矩阵存储

g[a][b] = a->b

邻接表存储

int h[N], e[N], ne[N], w[N], idx;
// 添加边a->b
void add(int a, int b, int c){
    e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}
int main(){
    memset(h, -1, sizeof h);
    ...
    return 0;
}

树与图的遍历

时间复杂度
\(O(n + m)\)\(n\) 表示点数, \(m\) 表示边数

DFS 深度优先遍历

dfs求树的重心

例题

AcWing 846. 树的重心

#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10, M = N << 1;

int h[N], e[M], ne[M], w[M], idx, n, fa[N], sz[N];

void add(int a, int b){
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

void dfs(int u, int p){
	sz[u] = 1;
	for(int i = h[u]; ~i; i = ne[i]){
		int j = e[i];
		if(j == fa[u]) continue;
		fa[j] = u;
		dfs(j, u);
		sz[u] += sz[j];
	}
}

int main(){
	ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    memset(h, -1, sizeof h);
    cin >> n;
    for(int i = 0; i < n - 1; i++){
    	int a, b;
    	cin >> a >> b;
    	add(a, b), add(b, a);
    }
    dfs(1, -1);
    int ans = 0, mx = 1e9, idx = 0;
    for(int u = 1; u <= n; u++){
    	int t = 0;
    	for(int i = h[u]; ~i; i = ne[i]){
    		int v = e[i];
    		if(v == fa[u])
    			t = max(t, n - sz[u]);
    		else
    			t = max(t, sz[v]);
    	}
    	if(t < mx){
    		mx = t;
    		idx = u;    // 子树最大值最小的点为树的重心
    		ans = t;        
    	}
    }
    cout << ans << endl;
    return 0;
}

两次dfs求树的直径

#include<bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10, M = N << 1;
int e[M], ne[M], h[N], w[M], idx, n, m, dist[N];
int start, ed;
vector<int> path;       // 存直径上的所有端点,以直径某一端点为根,深度从小到大

void add(int a, int b, int c){
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}

void dfs(int u, int p){
    for(int i = h[u]; ~i; i = ne[i]){
        int j = e[i];
        if(j == p) continue;
        dist[j] = dist[u] + w[i];
        dfs(j, u);
    }
}

void find(int u, int p) {
    path.pb(u);
    if (u == ed) {
        flag = true;
        return;
    }
    for (int i = h[u]; ~i; i = ne[i]) {
        int v = e[i];
        if (v == p) continue;
        find(v, u);
        if (flag) return;
    }
    path.pop_back();
}

int main(){
    memseth, -1, sizeof h);
    scanf("%d", &n);
    for(int i = 1; i < n; i++){
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        add(a, b, c), add(b, a, c);
    }
    memset(dist, 0, sizeof dist);
    dfs(1, -1);
    int mx = 0;
    for(int i = 1; i <= n; i++)
        if(dist[i] >= mx){
            mx = dist[i];
            start = i;
        }
    memset(dist, 0, sizeof dist);
    dfs(start, -1);
    for (int i = 1; i <= n; i++) 
        if (dist[i] >= mx) {
            mx = dist[i];
            ed = i;
        }
    find(start, -1);        // 找出直径上端点
    int ans = *max_element(dist + 1, dist + 1 + n);     // ans 为直径长度
    return 0;
}

BFS 宽度优先遍历

有时可以用一个数组同时维护标记和距离功能

#include<queue>

void bfs(){
    queue<int> q;
    q.push(1);      // 首元素入列
    st[1] = true;   // 入列就标记
    while(q.size()){    // 队列不为空
        int t = q.front();  // 取队首
        q.pop();    // 记得pop
        for(int i = h[t]; ~i; i = ne[i]){    // 遍历所有方向
            int j = e[i];
            if(!st[j]){     // 没到过或符合遍历条件
                q.push(j)
                st[j] = true;   // 入队并标记
            }
        }
    }
}

拓扑排序

拓扑图又称有向无环图,拓扑序列中左右顺序一定满足边的起点到终点。

void topo(){
    queue<int> q;
    for(int i = 1; i <= n; i++){
        if(!indegree[i]){
            q.push(i);
            ans.pb(i);
        }
    }
    while(q.size()){
        auto t = q.front();
        q.pop();
        for(int i = h[t]; i != -1; i = ne[i]){
            int j = e[i];
            indegree[j]--;
            if(!indegree[j]){
                q.push(j);
                ans.pb(j);
            }
        }
    }
}

最短路问题

\(n\) 表示点数, \(m\) 表示边数。

  • 单源最短路
    • 所有边权都是正数
      • 朴素 dijkstra 算法 \(O(n^2)\) (适合稠密图)
      • 堆优化 dijkstra 算法 \(O(m*logn)\) (适合稀疏图,\(m\)\(n\) 一个数量级)
    • 存在负权边
      • Bellman-Ford 算法 \(O(n*m)\) (不超过 \(k\) 条边可以使用)
      • SPFA 算法 一般 \(O(m)\), \(O(n*m)\)
  • 多源汇最短路
    • Floyd 算法 \(O(n^3)\)

Dijkstra 算法

朴素Dijkstra \(O(n^2)\)

#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;

const int N = 510;  // 点数
int g[N][N],dist[N]; // g[][]以邻接矩阵形式存取点和边,dist[]存各点到第一个点的最短距离
bool st[N]; // 记录每个点是否已经确定最短路
int n,m;

int dijkstra(){
    memset(dist,0x3f,sizeof dist);

    dist[1] = 0;
    // 循环n次
    for(int i = 1;i <= n;i++){  
        int t = -1;
        // 找出未确定最短路的点中,距第一个点距离最短的点 (贪心证明,硬记即可)
        for(int j = 1;j <= n;j++)
            if(!st[j] && (t == -1 || dist[t] > dist[j]))
                t = j;

        st[t] = true;   // 标记

        for(int j = 1;j <= n;j++)
            dist[j] = min(dist[j],g[t][j] + dist[t]);
    }
    if(dist[n] == 0x3f3f3f3f) return -1;
    else return dist[n];
}


int main(){
    scanf("%d%d",&n,&m);
    memset(g,0x3f,sizeof g);
    while(m--){
        int x,y,z;
        scanf("%d%d%d",&x,&y,&z);
        // 自环自动忽略,重边取最短
        g[x][y] = min(g[x][y],z);
    }

    cout << dijkstra() << endl;
    return 0;
}

堆优化Dijkstra \(O(m*logm = m *logn)\)

#include<iostream>
#include<algorithm>
#include<cstring>
#include<queue>
using namespace std;
const int N = 150010;
// 代表一个点,first 表示 点到根节点最短距离,second 表示 点的编号
typedef pair<int,int> PII;      

int e[N << 1],ne[N << 1],h[N],w[N << 1],idx;   // 稀疏图,采用邻接表的形式存储
int n,m,dist[N];    // dist[] 记录各点到根节点的最短距离
bool st[N];     // 标记点是否已更新出最小值

void add(int a,int b,int c){
    w[idx] = c; e[idx] = b; ne[idx] = h[a];h[a] = idx++;
}

int dijkstra(){
    // 优先队列,递增存储建立小根堆,存储最短边,此种堆不能删除修改指定元素
    priority_queue <PII, vector<PII>, greater<PII>> heap;   
    memset(dist,0x3f,sizeof dist);      // 初始化dist[]为无穷大
    // 第一个节点到自己距离为0
    dist[1] = 0;    
    heap.push({0,1});   

    while(heap.size()){
        PII t = heap.top();     // 取出最短距离的点
        heap.pop();
        int distance = t.first, ver = t.second; // 提取最短距离的编号

        if(st[ver]) continue;   // 若该点已更新了到根节点最短距离,continue
        st[ver] = true;      // 未更新,标记为true

        // 用该点更新其他点的距离
        for(int i = h[ver];i != -1;i = ne[i]){
            int j = e[i];
            // 如果(j 到原点的距离) > (原点到 t 的距离 + t 到 j 的距离) ,更新j的最短距离
            if(dist[j] > distance + w[i]){
                dist[j] = distance + w[i];
                heap.push({dist[j],j});     // 将j点更新的信息加入堆中
            }
        }
    }
    if(dist[n] == 0x3f3f3f3f) return -1;        // 表明n的距离未被更新,不存在连通的边从原点到n
    return dist[n];
}

int main(){
    scanf("%d%d",&n,&m);
    memset(h,-1,sizeof h); // 邻接表初始化表头为-1
    while(m--){
        int x,y,z;
        scanf("%d%d%d",&x,&y,&z);
        add(x,y,z);     // 后续函数自动使用最短边,可以不考虑重边问题
    }
    cout << dijkstra() << endl;
    return 0;
}

Bellman-Ford 算法

const int N = 510,M = 10010, INF = 0x3f3f3f3f;
int dist[N],last[N]; // dist[]记录到原点的距离,last[]做备份,因为负权边的存在防止串联
int n,m,k;
// 结构体数组存 点和边
struct Edge{
    int a,b,w;
}edges[M];

int bellman_ford(){
    memset(dist, 0x3f,sizeof dist);
    dist[1] = 0;
    for(int i = 0;i < k;i++){       // 最多走 k 条边的最短距离
        memcpy(last,dist,sizeof dist);      // 拷贝
        for(int j = 0;j < m;j++){           // 更新最短距离
            auto e = edges[j];
            // 与dijkstra相似,注意后者距离为备份距离last[e.a]
            dist[e.b] = min(dist[e.b], last[e.a] + e.w);    
        }
    }
    return dist[n];     // 存在负权边,判断dist[n] > INF / 2 就输出 impossible,具体值根据题目范围
}

spfa 算法

spfa求最短路

\(spfa\)\(dijkstra\) 代码相似,与 \(bellman_ford\) 算法主要在于:

  • \(dist\) 值变小的点来更新其他的点,而不是遍历了所有的边
int spfa(){
    memset(dist,0x3f,sizeof dist);      // 初始化距离
    dist[1] = 0;
    queue<int> q;       // 建立队列,并放入第一个点,队列存取的是点的编号
    q.push(1);
    st[1] = true;       // 标记第1个点已经放入队列
    // 队列不空
    while(q.size()){    
        int t = q.front();      // 取队头
        q.pop();
        st[t] = false;  // 标记队头已不在队列中
        for(int i = h[t];i != -1;i = ne[i]){
            int j = e[i];
            if(dist[j] > dist[t] + w[i]){
                dist[j] = dist[t] + w[i];       // 更新距离
                // 如果j未在队列中,则将其放入队列来更新其他点
                if(!st[j]){ 
                    q.push(j);
                    st[j] = true;
                }
            }
        }
    }
    if(dist[n] == 0x3f3f3f3f) return -1;
    else return dist[n];
}

spfa 判断负环

  • 不用初始化 \(dist\) 数组
  • 要把每个点都预先加入队列中
int ne[N],e[N],w[N],dist[N],h[N],idx,n,m;
bool st[N];
int cnt[N];     // 记录到某一点所需要经历的边数
// 代码主体与spfa相同,区别点主要有两处: 无距离dist的初始化,每个点都预先放入队列之中
void add(int a,int b,int c){
    e[idx] = b,w[idx] = c,ne[idx] = h[a],h[a] = idx++;
}

bool spfa(){
    queue<int> q;
    for(int i = 1;i <= n;i++){
        q.push(i);
        st[i] = true;
    }
    while(q.size()){
        int t = q.front();
        q.pop();
        st[t] = false;
        for(int i = h[t];i != -1;i = ne[i]){
            int j = e[i];
            if(dist[j] > dist[t] + w[i]){
                dist[j] = dist[t] + w[i];
                cnt[j] = cnt[t] + 1;    // 在前一个点的cnt的基础上+1
                // cnt大于n说明其中至少有n+1个点,由抽屉原理可知,其中一定有环并且是负环
                if(cnt[j] >= n) return true;    
                if(!st[j]){
                    q.push(j);
                    st[j] = true;
                }
            }
        }
    }
    return false;   // 无异常返回false
}

Floyd 算法

  • 基于 \(DP\)\(f[k][i][j]\) 表示从 \(i\) 出发经过 \(k\) 点到 \(j\) 的最短路
  • 处理自环就是把 \(g[i][i] = 0\)

状态转移方程:

\[g[k][i][j] = min(g[k][i][j], g[k-1][i][k] + g[k-1][k][j]) \]

void init(){
    memset(g, 0x3f, sizeof g);
    for(int i = 1; i <= n; i++)
        g[i][i] = 0;    // 自环为0
}

void floyd(){       
    for(int k = 1; k <= n; k++)     
        for(int i = 1; i <= n; i++)
            for(int j = 1; j <= n; j++)
                g[i][j] = min(g[i][j], g[i][k] + g[k][j]);   // 优化掉了第一维空间
}

最小生成树

\(n\) 代表点数, \(m\) 代表边数

  • \(Prim\) 算法
    • 朴素版 \(Prim\) \(O(n^2)\) (稠密图)
    • 堆优化 \(Prim\) \(O(m*logn)\) (不常用)
  • \(Kruskal\) 算法 \(O(m*logm)\) (稀疏图)

朴素版 Prim 算法 \(O(n^2)\)

#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
const int N = 510,INF = 0x3f3f3f3f;

int g[N][N],dist[N];    // 邻接矩阵存图中点和边,dist[]存取某点到集合的距离(到集合中所有点的最短距离)
int n,m;
bool st[N]; // 标记数组

int prim(){
    memset(dist,INF,sizeof dist);
    int res = 0;
    for(int i = 0;i < n;i++){
        int t = -1;     // t在for循环里面
        for(int j = 1;j <= n;j++)
            if(!st[j] && (t == -1 || dist[t] > dist[j]))
                t = j;

        if(i && dist[t] == INF) return INF;     // 表明图未连通
        if(i) res += dist[t];       // 先写res后更新dist,排除自环

        // 注意dist[j] 和 g[t][j] 取最小,和dijkstra加以区分
        for(int j = 1;j <= n;j++) dist[j] = min(dist[j],g[t][j]);

        // 标记
        st[t] = true;
    }
    return res;
}

int main(){
    cin >> n >> m;
    memset(g,INF,sizeof g); // 初始化边为无穷大
    while(m--){
        int u,v,w;
        scanf("%d%d%d",&u,&v,&w);
        g[u][v] = g[v][u] = min(g[u][v],w); // 无向图
    }
    int t = prim();
    if(t == INF) puts("impossible");
    else printf("%d",t);
    return 0;
}

堆优化版 Prim 算法 \(O(m * logn)\)

#include <cstring>
#include <iostream>
#include <queue>
using namespace std;

const int MAXN = 510, MAXM = 2 * 1e5 + 10, INF = 0x3f3f3f3f;
typedef pair<int, int> PII;
int h[MAXM], e[MAXM], w[MAXM], ne[MAXM], idx;
bool vis[MAXN];
int n, m;

void add(int a, int b, int c) {
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}

int Prim()
{
    memset(vis, false, sizeof vis);
    int sum = 0, cnt = 0;
    priority_queue<PII, vector<PII>, greater<PII>> q;
    q.push({0, 1});

    while (!q.empty())
    {
        auto t = q.top();
        q.pop();
        int ver = t.second, dst = t.first;
        if (vis[ver]) continue;
        vis[ver] = true, sum += dst, ++cnt;

        for (int i = h[ver]; i != -1; i = ne[i])
        {
            int j = e[i];
            if (!vis[j]) {
                q.push({w[i], j});
            }
        }
    }

    if (cnt != n) return INF;
    return sum;
}

int main()
{
    cin >> n >> m;
    memset(h, -1, sizeof h);
    for (int i = 0; i < m; ++i)
    {
        int a, b, w;
        cin >> a >> b >> w;
        add(a, b, w);
        add(b, a, w);
    }

    int t = Prim();
    if (t == INF) cout << "impossible" << endl;
    else cout << t << endl; 
}

Kruskal 算法 \(O(m*logn)\)

  • 从小到大对所有边排序
  • 枚举所有边,端点不在同一个连通块内就连起来,只到连完所有点
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
const int N = 2e5 + 10, INF = 0x3f3f3f3f;
int n, m, p[N];

void init(){        // 初始化并查集
    for(int i = 1; i <= n; i++)
        p[i] = i;
}

int find(int x){        // 查找操作
    if(x != p[x]) return p[x] = find(p[x]);
    return p[x];
}

struct Edge{        // 结构体存边
    int a, b, w;
    bool operator < (const Edge& W) const{
        return w < W.w;
    }
}edges[N];

int kruskal(){
    int cnt = 0, res = 0;;
    for(int i = 0; i < m; i++){
        if(cnt >= n - 1) break;
        int a = edges[i].a, b = edges[i].b, w = edges[i].w;
        a = find(a), b = find(b);
        if(a != b){     // a 和 b 不在一个连通块内
            p[b] = a;
            cnt++;
            res += w;
        }
    }
    if(cnt < n - 1)     // 连接的边数小于 n - 1 表明没有最小生成树
        return INF;
    return res;
}


int main(){
    cin >> n >> m;
    init();
    for(int i = 0; i < m; i++){
        int a, b, c;
        cin >> a >> b >> c;
        edges[i] = {a, b, c};
    }
    sort(edges, edges + m);     // 从小到大排序所有边
    int ans = kruskal();
    if(ans == INF)
        puts("impossible");
    else
        printf("%d", ans);
    return 0;
}

次小生成树

次小生成数是最小生成树的邻集,仅有一条边与最小生成树不同。
两种方式求次小生成树:

  • 求最小生成树,枚举删去最小生成树的每条边,再求依次最小生成树(不能保证得到严格此小生成树) \(O(mlogm + nm)\)
  • 求最小生成树,标记每条边是否在最小生成树中,预处理树中点到点之间路径中边权最大值次大值,枚举非树边,将该边添加到树中,并删去最大边权或者次大边权(枚举边长度与最大边权相等时)的树边,最后得到此小生成树(可以得到严格此小生成树)\(O(mlogm + n^2)\)
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 510, M = 1e4 + 10;
int h[N], e[2 * N], ne[2 * N], w[2 * N], idx;
int dist1[N][N], dist2[N][N], n, m, p[N];       // dist1[][] 存点到点最大边权,dist[][] 存次大边权

// 求次小生成树

struct Edge{
    int a, b, w;
    bool f;
    bool operator < (const Edge & W) const{
        return w < W.w;
    }
}edges[M];

void init(){
    for(int i = 1; i <= n; i++)
        p[i] = i;
}

int find(int x){
    if(x != p[x]) p[x] = find(p[x]);
    return p[x];
}

void add(int a, int b, int c){
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}

void dfs(int u, int fa, int maxd1, int maxd2, int d1[], int d2[]){
    d1[u] = maxd1;      // 根据最小生成树的特性,可以直接update
    d2[u] = maxd2;
    for(int i = h[u]; ~i; i = ne[i]){
        int j = e[i];
        if(j != fa){
            int t1 = maxd1, t2 = maxd2;     // 一定用临时变量存储,后续还要用到maxd1, maxd2;
            if(w[i] > t1){      // 更新最大边和次大边
                t2 = t1;
                t1 = w[i];
            }
            else if(w[i] > t2 && w[i] < t1)     // 更新次大边
                t2 = w[i];
            dfs(j, u, t1, t2, d1, d2);
        }
    }
}

int main(){
    cin >> n >> m;
    init();
    memset(h, -1, sizeof h);
    for(int i = 0; i < m; i++)
        cin >> edges[i].a >> edges[i].b >> edges[i].w;
    sort(edges, edges + m);
    ll sum = 0, res = 1e18;
    for(int i = 0; i < m; i++){
        auto t = edges[i];
        int a = t.a, b = t.b, w = t.w;
        int pa = find(a), pb = find(b);
        if(pa != pb){
            p[pa] = pb;
            sum += w;
            edges[i].f = true;      // 标记为树边
            add(a, b, w), add(b, a, w);     // 建树
        }
    }
    for(int i = 1; i <= n; i++)         // 预处理点到点之间的最大边权和次大边权
        dfs(i, -1, 0, 0, dist1[i], dist2[i]);
        
    for(int i = 0; i < m; i++){
        if(!edges[i].f){
            auto t = edges[i];
            int a = t.a, b = t.b, w = t.w;
            if(w > dist1[a][b])     // 大于最大边
                res = min(res, sum + w - dist1[a][b]);
            else if(w > dist2[a][b])    // 小于等于最大边权,大于次大边权
                res = min(res, sum + w - dist2[a][b]);
        }
    }
    cout << res << endl;
    
    return 0;
}

二分图

  • 染色法 \(O(n + m)\)

  • 匈牙利算法 \(O(m*n)\),实际小于 \(O(m*n)\)

二分图无奇数环,分成的两部分中无边,即原图中一条边的两端点不能是同色,否则不是二分图

染色法判断是否二分图

#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
const int N = 2e5 + 10;

// 模板思路:二分图无奇数环,分成的两部分中无边,即原图中一条边的两端点不能是同色,否则不是二分图

int n,m;
int e[N],ne[N],h[N],idx;    // 邻接表存图
int color[N];   // 色彩数组,分两个颜色,1和2

void add(int a, int b){
    e[idx] = b,ne[idx] = h[a],h[a] = idx++;
}

bool dfs(int u,int c){
    color[u] = c;
    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i];
        if(!color[j]){
            if(!dfs(j,3 - c)) return false; // 染成 3-c 的颜色,原1则染2,原2则染1
        }
        else if(color[j] == c) return false;    // 和另一点颜色相同,非二分图
    }
    return true;
}

int main(){
    scanf("%d%d",&n,&m);
    memset(h,-1,sizeof h);
    while(m--){
        int a,b;
        scanf("%d%d",&a,&b);
        add(a,b),add(b,a);
    }
    bool flag = true;
    for(int i = 1;i <= n;i++){
        if(!color[i])
            if(!dfs(i,1)){          // 染色出问题
                flag = false;
                break;
            }
    }
    if(flag) puts("Yes");
    else puts("No");

}

二分图的最大匹配数

#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
const int N = 510,M = 1e5 + 10;

int ne[M],e[M],h[N],idx;    // 邻接表存图
int n1,n2,m;
int match[N];   // 存每个女生对应的是哪个男生
bool st[N];     // 标记对于某个男生来说,女生是否已经匹配过

void add(int a,int b){
    e[idx] = b,ne[idx] = h[a],h[a] = idx++;
}

bool find(int x){
    for(int i = h[x];i != -1;i = ne[i]){
        int j = e[i];
        if(!st[j]){
            st[j] = true;
            // 女生还没有匹配过,或已经匹配的男生可以再找另一个合适的
            if(match[j] == 0 || find(match[j])){    
                match[j] = x;
                return true;
            }
        }
    }
    return false;   // 都不行就返回false
}

int main(){
    scanf("%d%d%d",&n1,&n2,&m);
    memset(h,-1,sizeof h);      // 初始化表头
    while(m--){
        int a,b;
        scanf("%d%d",&a,&b);
        add(a,b);
    }
    int res = 0;
    for(int i = 1;i <= n1;i++){
        memset(st,false,sizeof st);             // 对于每个男生要从前向后匹配,需要memset
        if(find(i)) res++;
    }
    printf("%d\n",res);
    return 0;
}

LCA 最近公共祖先

倍增法求解LCA

  • 先计算出两个节点 \(u,v\) 的深度;
  • \(u,v\) 调整到同一深度
  • 两个节点一起逐级向上跳,直到两节点相等
int dep[N], fa[N][Lg];

void dfs(int u, int p){
	for(int i = h[u]; ~i; i = ne[i]){
		int j = e[i];
		if(j == p) continue;
		dep[j] = dep[u] + 1;
		fa[j][0] = u;
		dfs(j, u);
	}
}

void init(){
	for(int i = 1; i < Lg; i++)		// 先循环跳的次数
		for(int j = 1; j <= n; j++)		// 再循环节点个数
			if(fa[j][i - 1])
				fa[j][i] = fa[fa[j][i - 1]][i - 1];
}

int lca(int a, int b){		// 求点 a 和 点 b 的最近公共祖先
	if(dep[a] < dep[b])
		swap(a, b);
	int d = dep[a] - dep[b];		// 深度大的是 a
	for(int i = 0; i < Lg && d; i++, d /= 2){
		if(d & 1)
			a = fa[a][i];
	}
	if(a == b) return a;
	for(int i = Lg - 1; i >= 0; i--)
		if(fa[a][i] != fa[b][i])
			a = fa[a][i], b = fa[b][i];
	return fa[a][0];
}

Tarjan

Tarjan求割点,割边

struct Tarjan {
    int dfn[N], low[N], id[N], stk[N];
    int tot, root, vertexNum, cnt, top;
    bool cut[N];
    vector<int> v_dcc[N];
    void init() {
        for (int i = 0; i <= idx; i++) h[i] = -1;
        for (int i = 1; i <= vertexNum; i++)
            dfn[i] = low[i] = id[i] = cut[i] = 0;
        for (int i = 0; i <= cnt; i++)  v_dcc[i].clear();
        idx = tot = cnt = vertexNum = top = 0;
    }
    void dfs(int u, int in_edge) {
        dfn[u] = low[u] = ++tot;
        stk[++top] = u;
        for (int i = h[u], flag = 0; ~i; i = ne[i]) {
            int v = e[i];
            if (!dfn[v]) {
                dfs(v, i);
                low[u] = min(low[u], low[v]);
                if (dfn[u] <= low[v]) {     // 会形成新的点双连通分量
                    flag ++, cnt++;
                    if (flag > 1 || u != root) { // 是根的话要两个子树跳不上去,否则是一个子树跳不上去
                        cut[u] = true;
                    }
                    int vv;
                    do {
                        vv = stk[top--];
                        v_dcc[cnt].push_back(vv);
                    } while (vv != v);
                    v_dcc[cnt].push_back(u);        // 点双连通末尾多添加一个割点
                }
            }
            else if (i != (in_edge ^ 1))
                low[u] = min(low[u], dfn[v]);
        }
    }
    void work(int _vertexNum) {
        vertexNum = _vertexNum;
        for (int i = 1; i <= vertexNum; i++) {
            if (!dfn[i]) dfs(root = i, -1);
        }
        // extension
    }
} tarjan;
struct Tarjan {
    int dfn[N], low[N], id[N], stk[N];
    int tot, cnt, top, vertexNum;
    bool bridge[M], ins[N];
    vector<int> e_dcc[N];
    void init() {
        for (int i = 0; i <= idx; i++) h[i] = -1, bridge[i] = false;
        for (int i = 1; i <= vertexNum; i++)
            dfn[i] = low[i] = id[i] = ins[i] = 0;
        for (int i = 0; i <= cnt; i++) e_dcc[i].clear();
        idx = tot = cnt = vertexNum = top = 0;
    }
    void dfs(int u, int in_edge) {
        dfn[u] = low[u] = ++tot;
        stk[++top] = u, ins[u] = true;
        for (int i = h[u]; ~i; i = ne[i]) {
            int v = e[i];
            if (!dfn[v]) {
                dfs(v, i);
                low[u] = min(low[u], low[v]);
                if (dfn[u] < low[v])   
                    bridge[i] = bridge[i ^ 1] = true;
            }
            else if (i != (in_edge ^ 1))
                low[u] = min(low[u], dfn[v]);
        }
        if (dfn[u] == low[u]) {
            ++cnt;
            int v;
            do {
                v = stk[top--];
                ins[v] = false;
                id[v] = cnt;
                e_dcc[cnt].push_back(v);
            } while (u != v);
        }
    }
    void work(int _vertexNum) {
        vertexNum = _vertexNum;
        for (int i = 1; i <= vertexNum; i++) {
            if (!dfn[i]) dfs(i, -1);
        }
        // extension
    }
} tarjan;

Tarjan求强连通分量

struct Tarjan {
    int dfn[N], low[N], stk[N], id[N], cnt, top, tot;
    bool ins[N];
    vector<int> scc[N];
    void dfs(int u) {
        dfn[u] = low[u] = ++tot;
        stk[++top] = u, ins[u] = true;
        for (int i = h[u]; ~i; i = ne[i]) {
            int v = e[i];
            if (!dfn[v]) {
                dfs(v);
                low[u] = min(low[u], low[v]);
            }
            else if (ins[v]) {
                low[u] = min(low[u], dfn[v]);
            }
        }
        if (low[u] == dfn[u]) {
            int v;
            cnt++;
            do {
                v = stk[top--];
                id[v] = cnt;
                scc[cnt].push_back(v);
                ins[v] = false;
            } while (u != v);
        }
    }
    void work(int vertexNum) {
        for (int i = 1; i <= vertexNum; i++) {
            if (!dfn[i])
                dfs(i);
        }
        // extension
    }
}tarjan;

数学

快速幂&高精度

快速幂

typedef long long ll;

// solution1
ll qmi(ll a, ll k, ll p){ 	// a^k mod p
    ll res = 1;
    while(k){
        if(k & 1)
            res = res * a % p; // int res = 1LL * res * a % p
            k >>= 1;
            a = a * a % p; // int a = 1LL * a * a % p;
    }
    return res;
}

// solution 2
int qmi(int a, long long b){
    int res = 1;
    for(; b; b >>= 1){
        if(b & 1)
            res = (long long) res * a % mod;
            a = (long long) a * a % mod;
    }
    return res;
}

// solution 3
long long quick_pow(long long x,long long y, ll p)
{
	if(y==1) return x;
	if(y==0) return 1; //1,2两种情况的代码一定要放在第3种情况之前
	if(y%2==0) return quick_pow(x*x%p,y/2);
	if(y%2!=0) return x*quick_pow(x*x%p,y/2)%p;
}

快速加(高精度乘)

有模数,\(a*b\mod p\)

// solution1 类似于快速幂思想,将k拆为二进制数
typedef long long ll;
ll qmult(ll a, ll k, ll p){
    ll res = 0;
    while(k){
        if(k & 1)
            res = (res + a) % p;
        k >>= 1;
        a = (a * 2) % p;
    }
    return res;
}

// solution2 源自进阶指南,利用ull自动取模,和long double的保留整数
ull mul(ull a, ull b, ull p){
    a %= p, b %= p;
    ull c = (long double) a * b / p;
    ull x = a * b, y = c * p;
    ull res = x - y;
    if(res < 0) res += p;
    return res;
}

无模数,\(a*b\)

vector<int> mul(vector<int> a, int b)
{
    vector<int> c;
    int t = 0;
    for (int i = 0; i < a.size(); i ++ )
    {
        t += a[i] * b;
        c.push_back(t % 10);
        t /= 10;
    }
    while (t)
    {
        c.push_back(t % 10);
        t /= 10;
    }
    // while(C.size()>1 && C.back()==0) C.pop_back();//考虑b==0时才有pop多余的0 b!=0不需要这行
    return c;
}

质数

试除法判定质数

bool is_prime(int x)
{
    if (x < 2) return false;
    for (int i = 2; i <= x / i; i ++ )
        if (x % i == 0)
            return false;
    return true;
}

试除法分解质因数

无论怎么分解,循环范围一定在 \(sqrt(n)\) 之内,大于 \(sqrt(n)\) 的质数单独判断。

void divide(int x)
{
    for (int i = 2; i <= x / i; i ++ )
        if (x % i == 0)	// i是x的因数,并且一定是质数,如果是合数,x一定早就被i因数整除。所以不是合数
        {
            int s = 0;
            while (x % i == 0) x /= i, s ++ ;
            cout << i << ' ' << s << endl;	// i位因数 s为指数
        }
    if (x > 1) cout << x << ' ' << 1 << endl;		// 注意会剩最大因数
    cout << endl;
}

// another version
int p[N], c[N], cnt;     // p[]存所有的质因数, c[]存对应质因数的指数, cnt存质因数个数
void divide(int n) {
    cnt = 0;
    for(int i = 2; i <= n / i; i++){
        if(n % i == 0){
            p[++cnt] = i, c[cnt] = 1;
            while(n % i == 0) n /= i, c[cnt]++;
        }
    }
    if(n > 1)
        p[++cnt] = n, c[cnt] = 1;
}

朴素筛 \(O(nloglogn)\)

int primes[N], cnt;     // primes[]存储所有素数
bool st[N];         // st[x]存储x是否被筛掉

void get_primes(int n)
{
    for (int i = 2; i <= n; i ++ )
    {
        if (st[i]) continue;
        primes[cnt ++ ] = i;
        for (int j = i + i; j <= n; j += i)
            st[j] = true;
    }
}

线性筛(欧拉筛)\(O(n)\)

int primes[N], cnt;     // primes[]存储所有素数
bool st[N];         // st[x]存储x是否被筛掉

void get_primes(int n)
{
    for (int i = 2; i <= n; i ++ )
    {
        if (!st[i]) primes[cnt ++ ] = i;
        for (int j = 0; primes[j] <= n / i; j ++ )
        {
            st[i * primes[j]] = true;
            if (i % primes[j] == 0) break;
        }
    }
}

约数

gcd最大公约数

typedef long long ll;
ll gcd(ll a, ll b){
    return b ? gcd(b, a % b) : a;
}

约数个数及约数之和

  • \(对a\subset Z进行质因数分解:\bf{a=\Pi_{i=1}^kp_i^{n_i}},k是质因数个数,n_i是一个质因数指数\)

  • \(求对\forall a\subset N^+的约数个数:\)

    \(\quad \because对于每个p_i指数取0,1\dots n有n+1种选法,不同的p_i相乘可得到不同的约数。\)

    \(\quad \therefore故由乘法原理得,约数个数为 \prod_{i=1}^k(n_i + 1)\)

  • \(求对\forall a\subset N^+的所有约数之和:\)

    \(\because \forall m是a的约数,m由每个质因数取其不同的指数相乘而得\)

    \(\therefore 约数之和=(p_1^0+p_1^1+\dots+p_1^{n_0})\times\dots\times()(p_k^0+p_k^1+\dots+p_k^{n_k})\)

    \(将上述式子简化得:\)

    \(约数之和=\Pi_{j=1}^k(\sum_{i=0}^np_j^i)\)

unordered_map<int, int> primes;
const int p = 1e9 + 7;
typedef long long ll;
// 约数个数定理:质因数分解a= π(pi^(ai^n)),在每个pi部分有0~ai^n共 n + 1种取法,乘法原理可得结论
// 约数之和定理: a = π(∑pi^a(ij)) (i外,j内)
int main(){
    int n;
    scanf("%d", &n);
    while(n --){
        int x;
        scanf("%d", &x);
        for(int i = 2; i <= x / i; i++){
            while(x % i == 0){          // while
                primes[i] ++;
                x /= i;
            }
        }
        if(x > 1)
            primes[x]++;
    }
    // 约数个数
    ll ans = 1;
    for(auto t: primes){
        ans = ans * (t.second + 1) % p;
    }
    printf("%lld\n", ans);
    // 约数之和
    ll ans = 1;
    for(auto t: primes){
        ll k = 1;
        while(t.second--){
            k = (k * t.first + 1) % p;
        }
        ans = (ans * k) % p;
    }
    printf("%lld\n", ans);
    return 0;
}

欧拉函数

欧拉函数公式实现

\(\phi(n)=n\times\Pi_{i=1}^k(1-\frac{1}{p_i})\color{black},p_i是n的质因子,k是质因数个数\)

\(\phi(n)\) 大小为小于 \(n\)\(n\) 互质整数个数。

typedef long long ll;
// 欧拉函数:对n而言 = n * π(1 - 1/pi), pi为其每一个质因数
ll ans = a;
for(int i = 2; i <= a / i; i++){
    if(a % i == 0){
        ans = ans / i * (i - 1);        // 先除i,防小数,防溢出
        while(a % i == 0) a /= i;
    }
}
if(a > 1) ans = ans / a * (a - 1);
printf("%lld\n", ans);

筛法求欧拉函数 \(O(nloglogn)\)

\(n\) 以内的欧拉函数

const int N = 1e5 + 10;
// 欧拉函数: f(n) = n * π(1 - 1÷pi)
int primes[N], phi[N], cnt = 0;
bool st[N];

void get_eulers(int n){
    phi[1] = 1;         // 1的欧拉函数是1
    for(int i = 2; i <= n; i++){
        if(!st[i]){
            primes[cnt++] = i;
            phi[i] = i - 1;
        }
        for(int j = 0; primes[j] <= n / i; j++){
            st[primes[j] * i] = true;
            if(i % primes[j] == 0){
                phi[i * primes[j]] = phi[i] * primes[j];   
                // primes[j]在phi[i]中出现过,只乘primes[j]就行
                break;
            }
            // primes[j]是i * primes[j] 最小质因子,且不是i的质因子
            phi[i * primes[j]] = phi[i] * (primes[j] - 1);  // = phi[i] * (phi[primes[j])
        }
    }
}

扩展欧几里得算法(exgcd)

exgcd

\(对于一堆整数,求出\color{navy}{\exists x,y\subset Z}\color{black},使得\color{navy}{a\times x+b\times y=gcd(a,b)}\)

// 通过 d = gcd(a, b) = gcd(b, a % b) = ax + by = ax + b(y - b * (a / b))建立等式

void exgcd(int a, int b, int &x, int &y){
    if(!b){
        x = 1, y = 0;
        return;
    }
    // printf("before:a=%d b=%d x=%d y=%d\n", a, b, x, y);
    exgcd(b, a % b, y, x);      // 一直递归到边界,x,y具体值不会变
    // printf("after:a=%d b=%d x=%d y=%d\n", a, b, x, y);
    y -= a / b * x;     // 回溯更新x和y的值
}

解线性同余方程

\(\color{black}对每组数求出一个x_i使\color{purple}{a_i\times x_i\equiv b_i (mod\space m_i)},\color{black}无解输出 \color{red}{impossible}\)

\(方程转化为ax + my = b, 有解则一定有 gcd(a, m) | b\)

int main(){
    int n;
    scanf("%d", &n);
    while(n--){
        int a, b, m;
        scanf("%d%d%d", &a, &b, &m);
        int d = gcd(a, m);
        if(b % d){      // 有解 一定有 gcd(a, m) | b
            puts("impossible");
            continue;
        }
        int x = 0, y = 0;
        exgcd(a, m, x, y);
        printf("%d\n", (ll)x * b / d % m);      // x 乘 b / d 就行,要开开ll,并且%m
    }
    return 0;
}

中国剩余定理

高斯消元

高斯消元解线性方程组

#include<iostream>
#include<cmath>
#include<cstring>
#include<algorithm>
using namespace std;
const int N = 110;
const double eps = 1e-6;
int n;

double a[N][N];

int gauss(){
    int c, r;
    for(c = 0, r = 0; c < n; c ++){
        int t = r;
        // 找到绝对值最大的一行
        for(int i = r; i < n; i++)
            if(fabs(a[i][c]) > fabs(a[t][c]))
                t = i;
        if(fabs(a[t][c]) < eps)
            continue;
        // 交换到第r行,枚举列
        for(int i = c; i < n + 1; i++) swap(a[t][i], a[r][i]);
        // 主元系数为1,注意从后往前推
        for(int i = n; i >= c; i--) a[r][i] /= a[r][c];
        // 将下面每行减去 第r行的 a[i][c] 倍
        for(int i = r + 1; i < n; i++)
            if(fabs(a[i][c]) > eps)
                for(int j = n; j >= c; j--)
                    a[i][j] -= a[i][c] * a[r][j];

        r ++;   // 秩+1
    }
    // 秩 < n
    if(r < n){  
        for(int i = r; i < n; i++)
            if(fabs(a[r][n]) > eps)     // 0=x (x!=0),无解
                return 2;           // 无解
        return 1;       // 无穷解
    }

    // 倒推算出每个x
    for(int i = n - 1; i >= 0; i--)     // a[i][n] -= 枚举首元之后的系数 a[i][j] * 对应的x_n的值 a[j][n]
        for(int j = i + 1; j < n; j++){
            a[i][n] -= a[i][j] * a[j][n];
        }

    return 0;
}

int main(){
    cin >> n;
    for(int i = 0; i < n; i++)
        for(int j = 0; j < n + 1; j++)
            cin >> a[i][j];

    int res = gauss();
    if(res == 1)
        cout << "Infinite group solutions" << endl;
    else if(res == 2)
        cout << "No solution" << endl;
    else{
        for(int i = 0; i < n; i++)
            printf("%.2lf\n", a[i][n]);
    }
    return 0;
}

组合数

\[C_a^b=\frac{a!}{b! * (a - b)!} \]

\[C_a^b = \frac{a*(a-1)*\cdots*(a-b+1)}{b!} \]

基础模型

隔板法

AcWing基础课讲到了 \(4\) 种求组合数的方法,分别应对不同的数据范围

组合数 \(I\)

题意

\(C_a^b \mod (10^9 + 7)\)

数据范围

\(1≤n≤10000, 1≤b≤a≤2000\)

思路

\(C_a^b = C_{a-1}^{b-1} + C_{a-1}^{b}\)

const int N = 2010, mod = 1e9 + 7;
int c[N][N];        // c[a][b] -> C_a^b

void init(){
    for(int i = 0; i < N; i++)
        for(int j = 0; j <= i; j++)     // b <= a
            if(!j) 
                c[i][j] = 1;
            else
                c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;        // % mod
}

int main(){
    int T;
    cin >> T;
    init();
    while(T--){
        int a, b;
        cin >> a >> b;
        cout << c[a][b] << endl;
    }

    return 0;
}

组合数 \(II\)

题意

\(C_a^b \mod (10^9 + 7)\)

数据范围

\(1≤n≤10000, 1≤b≤a≤10^5\)

思路

  • 预处理出分子分母的阶乘
  • \(a/b \equiv a *b^{-1} \mod p\), \(b^{-1}\)\(b\)\(p\) 的乘法逆元
const int N = 100010, mod = 1e9 + 7;
typedef long long ll;
int fact[N], infact[N];

// 预处理阶乘(n * logn)

int qmi(int a, int k){
    int res = 1;
    while(k){
        if(k & 1)
            res = 1LL * res * a % mod;
        a = 1LL * a * a % mod;
        k >>= 1;
    }
    return res;
}

int main(){
    fact[0] = 1, infact[0] = 1;
    for(int i = 1; i < N; i++){
        fact[i] = 1LL * fact[i - 1] * i % mod;
        // a / b 同余 a * b^{-1} % mod -> a * b^{mod - 2}; 费马小定理
        infact[i] = 1LL * infact[i - 1] * qmi(i, mod - 2) % mod;    
        }
    int T;
    cin >> T;
    while(T--){
        int a, b;
        cin >> a >> b;
        // 中间 % mod 防溢出long long
        cout << 1LL * fact[a] * infact[b] % mod * infact[a - b] % mod << endl;         
        }
    
    
    return 0;
}

组合数 \(III\) ( \(Lucas\) 定理)

题意

输入 \(a,b,p\) ,求 \(C_a^b\mod p\) 的值

数据范围

\(1≤n≤20\)
\(1≤b≤a≤10^{18}\)
\(1≤p≤10^5\)

思路

  • 运用 \(Lucas\) 定理:

\[C_a^b\equiv C_{a \% p}^{b \% p} * C_{a / p}^{b / p}\mod p \]

int qmi(int a, int k, int p){
    int res = 1;
    while(k){
        if(k & 1)
            res = 1ll * res * a % p;
        a = 1ll * a * a % p;
        k >>= 1;
    }
    return res;
}

int C(int a, int b, int p){
    if(b > a) return 0; 
    int res = 1;
    // C_a^b = \frac{a*(a-1)*\cdots *(a - b + 1)}{b!};
    for(int i = 1, j = a; i <= b; i++, j--){
        res = 1ll * res * j % p;
        // 费马小定理,a / b 同余 a * b^{-1} 模 p, b^{-1} = b^{p-2}
        res = 1ll * res * qmi(i, p - 2, p) % p;     
        }
    return res;
    }

// Lucas定理:
int lucas(ll a, ll b, int p){       // 参数 a,b 取 long long
    if(a < p && b < p) return C(a, b, p);
    return 1ll * C(a % p, b % p, p) * lucas(a / p, b / p, p) % p;
}


int main(){
    int T;
    cin >> T;
    while(T--){
        ll a, b, p;
        cin >> a >> b >> p;
        cout << lucas(a, b, p) << endl;
    }
    return 0; // 21
}

组合数 \(IV\)

题意

输入 \(a,b\) ,求 \(C_a^b\) 的值
结果很大需要高精度计算(没有 \(\%\) 运算)

数据范围

\(1\leq b\leq a\leq 5000\)

思路

  1. 筛质数
  2. 求出每个质数在阶乘中出现的次数
  3. 遍历质数们,高精度乘法
int primes[N], cnt;
int sum[N];
bool st[N];
int a, b;

// 线性筛
void get_primes(int n){
    for(int i = 2; i <= n; i++){
        if(!st[i])
            primes[cnt++] = i;
        for(int j = 0; primes[j] <= n / i; j++){
            st[i * primes[j]] = true;
            if(i % primes[j] == 0)
                break;
        }
    }
}

// 获取n!含p的指数
int get(int n, int p){
    int res = 0;
    while(n){
        res += n / p;
        n /= p;
    }
    return res;
}

// 高精度乘
vector<int> mult(vector<int> a, int b){
    vector<int> c;
    int t = 0;
    for(int i = 0; i < a.size(); i++){
        t += a[i] * b;
        c.push_back(t % 10);
        t /= 10;
    }
    while(t){
        c.push_back(t % 10);
        t /= 10;
    }
    return c;
}


int main(){
    cin >> a >> b;
    get_primes(a);
    
    for(int i = 0; i < cnt; i++){
        int p = primes[i];
        sum[i] = get(a, p) - get(b, p) - get(a - b, p);
    }
    vector<int> res;
    res.push_back(1);
    
    for(int i = 0; i < cnt; i++){
        int p = primes[i];
        for(int j = 0; j < sum[i]; j++){
            res = mult(res, p);
        }
    }
    // 倒序输出
    for(int i = res.size() - 1; i >= 0; i--)
        cout << res[i];
     
    return 0;
}

卡特兰数

\[方案数=C_{2n}^n - C_{2n}^{n - 1}=\frac{1}{n+1}*C_{2n}^n \]

例题

AcWing889.满足条件的01序列
AcWing129.火车进栈

容斥原理

时间复杂度
通常来说为 \(O(2^n)\)

\[\bigcup_{i=1}^nS_i=\sum_{i=1}^{n}S_i-(S_1\bigcap S_2+S_1\bigcap S_3+\cdots+S_{n-1}\bigcap S_n)+\cdots+(-1)^{n-1}\bigcap_{i=1}^nS_i \]

例题

AcWing890.能被整除的数

#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 20;
int p[N];

int main(){
    int n, m;
    cin >> n >> m;
    for(int i = 0; i < m; i++)
        cin >> p[i];
    ll res = 0;
    for(int i = 1; i < 1 << m; i++){    // 二进制枚举思想
        ll t = 1, cnt = 0;
        for(int j = 0; j < m; j++){
            if(i >> j & 1){
                t *= p[j];      // 此处可能溢出,开longlong
                if(t > n){
                    t = -1;
                    break;
                }
                cnt ++;
            }
        }
        if(t != -1){
            if(cnt & 1)     // 根据集合数量来决定加还是减
                res += n / t;  
            else
                res -= n / t;
        }
    }
    cout << res << endl;
    return 0;
}

矩阵快速幂

如何利用递推构造 base 矩阵是关键

template <typename T, size_t N> struct Mat {
	int len;
	Mat() { memset(data, 0, sizeof(data)); len = N; }

	T *operator[](int i) { return data[i]; }

	const T *operator[](int i) const { return data[i]; }

	T add(T a, T b){
		return (a + b) % mod;
	}

	Mat &operator += (const Mat &o) {
		for (int i = 0; i < len; ++i) 
			for (int j = 0; j < len; ++j) 
				data[i][j] = add(data[i][j], o[i][j]);
		return *this;
	}

	Mat operator + (const Mat &o) const {
		return Mat(*this) += o;
	}

	Mat &operator -= (const Mat &o) {
		for (int i = 0; i < len; ++i) 
			for (int j = 0; j < len; ++j) 
				data[i][j] = add(data[i][j], -o[i][j]);
		return *this;
	}

	Mat operator-(const Mat &o) const {
		return Mat(*this) -= o;
	}

	Mat operator*(const Mat &o) const {
		static T buffer[N];
		Mat result;
		for (int j = 0; j < len; ++j) {
			for (int i = 0; i < len; ++i) 
				buffer[i] = o[i][j];
			for (int i = 0; i < len; ++i) 
				for (int k = 0; k < len; ++k) 
					result[i][j] += (data[i][k] * buffer[k]) % mod;
		}
		return result;
	}

	Mat power(unsigned long long k) const {
		Mat res;
		for (int i = 0; i < len; ++i) 
			res[i][i] = T{1};
		Mat a = *this;
		while (k) {
			if (k & 1ll) 
				res = res * a;
			a = a * a;
			k >>= 1ll;
		}
		return res;
	}

	private:
		T data[N][N];
};

博弈论

性质

  • 有限性: 无论两人怎样决策,都会在有限步后决出胜负。
  • 公平性: 即两人进行决策所遵循的规则相同。

P/N状态

P-position: P代表Previous,上一次行动的人有必胜策略的局面是P-position,也就是“先手必败”

N-position: N代表Next,当前行动的人有必胜策略的局面是N-position,也就是“先手可保证必胜”

P点: 即必败点,在双方都在最优策略下,玩家位于此点必败
N点: 即必胜点,在双方都在最优策略下,玩家位于此点必胜

胜态与必败态

  • 若面临末状态者为获胜则末状态为胜态,否则末状态为必败态
  • 一个局面是胜态的充要条件是该局面进行某种决策后成为必胜态
  • 一个局面是必败态的充要条件是该局面无论进行哪种决策均会成为胜态

nim游戏

每堆石子数异或和不为0,则先手必赢,否则输。

SG函数

SG定理

\[SG(b1,b2) = SG(b1)\oplus SG(b2) \]

#include<unordered_set>
// SG定理:sg(b1, b2) = sg(b1) ^ sg(b2)
// SG函数
int sg(int x){
    if(f[x] != -1) return f[x];
    unordered_set<int> S;
    for(int i = 0; i < x; i++){
        for(int j = 0; j <= i; j++)
            S.insert(sg(i) ^ sg(j));
    }
    for(int i = 0; ; i++)       // sg函数异或值可能大于x, 不要写条件
        if(!S.count(i))
            return f[x] = i;
}

动态规划

\(dp\) 核心思想
一个集合代表了多种状态,形成优化

线性DP

数字三角形模型

例题:传纸条

#include<iostream>
#include<cstring>
using namespace std;
const int N = 55;
int w[N][N];
int f[N + N][N][N];         // 与取方格状态相同,不再重复
                            // 状态表示:f[k][i1][i2]表示走到(i1, k-i1)(i2, k-i2)取到最大数
                            // 当i1 + j1 == i2 + j2时,两条路径可能重合,终点取f[2n][n][n];
                            // 状态转移:f[k][i1][i2] = max(f[k-1][i1-1][i2],..[i1][i2-1],...[i1][i2],...[i1-1][i2-1]) + w;
                            // w取值重合时取w[i1][k - i1],不重合再加上w[i2][k - i2];
int main(){
    int n, m;
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= m; j++)
            scanf("%d", &w[i][j]);
    for(int k = 2; k <= n + m; k++){
        for(int i1 = 1; i1 <= n; i1++){
            for(int i2 = 1; i2 <= n; i2++){
                int j1 = k - i1, j2 = k - i2;
                if(j1 < 1 || j1 > m || j2 < 1 || j2 > m) continue;
                int t = w[i1][j1];
                if(i1 != i2) t += w[i2][j2];
                int &x = f[k][i1][i2];
                x = max(x, f[k - 1][i1][i2] + t);
                x = max(x, f[k - 1][i1 - 1][i2] + t);
                x = max(x, f[k - 1][i1][i2 - 1] + t);
                x = max(x, f[k - 1][i1 - 1][i2 - 1] + t);
                // cout << x << " " << endl;
            }
        }
    }
    printf("%d", f[n + m][n][n]);       // 两个坐标都要取到n,因为i1 i2都是横坐标
    return 0;
}

LIS模型

朴素 LIS \(O(n^2)\)

  • 状态表示: \(f[i]\) 表示以 \(a[i]\) 为结尾的最大上升子序列长度
  • 状态转移: \(f[i] = max(f[i], f[j] + 1) if a[i] >= a[j] and j < i\)

二分+贪心优化 \(O(nlogn)\)

  • 维护单调递增的序列,序列长度就是目标答案,每次找到一个数插入到大于等于它的第一个位置并替换之
int len = 0;
q[0] = -2e9;
for(int i = 1; i <= n; i++){
    if(a[i] > q[len]) q[++len] = a[i];
    else{
        int p = lower_bound(q, q + len, a[i]) - q;
        q[p] = a[i];
    }
}

背包模型

输入统一 \(n\) 表示物品个数, \(m\) 表示背包体积大小

滚动数组优化:

  • 用到\(f[i-1]\) 的状态,将体积从大到小枚举
  • 用到 \(f[i]\) 的状态,将体积从小到大枚举

注意⚠️:需要求具体方案不能进行滚动数组优化

01背包

核心:选与不选

一般题意:

​ 给定 \(n\) 个物品,每个物品价值为 \(w_i\) 且只有一个,背包容积为 \(v\) ,求出所有方案的最大价值

状态表示: \(f[i][j]\) 表示前 \(i\) 个物品中,体积不超过 \(j\) 的所有方案的最大价值

状态计算: \(f[i][j]=max(f[i-1][j],f[i -1][j - v_i] + w_i)\)

// 滚动数组(一维空间优化)
for(int i = 1; i <= n; i++)
    for(int j = m; j >= v[i]; j--)
        f[j] = max(f[j], f[j - v[i]] + w[i]);

完全背包

状态转移证明:

  • 用一个值 \(f[i][j -v] + w\) 代表了前缀所有最大值,递推优化

\(\because f[i][j] = max(f[i-1][j], f[i-1][j-v]+w,f[i-1][j-2v]+2w,\dots )\)

\(\& f[i][j-v] = max(f[i-1][j-v], f[i-1][j-v]+w, ...)\)

\(\therefore f[i][j] = max(f[i][j], f[i][j-v_i] + w_i)\)

一般题意:

​ 有 \(N\) 种物品和一个容量是 \(V\) 的背包,每种物品都有无限件可用。第 \(i\) 种物品的体积是 \(v_i\) ,价值是 \(w_i\)

​ 求解选哪些物品进入背包,可使背包在不超容的情况下总价值最大。

状态表示: \(f[i][j]\) 表示在前 \(i\) 个物品中选取总体积不超过 \(j\) 的所有方案的最大价值

状态计算: \(f[i][j]=max(f[i-1][j],f[i][j - v_i] + w_i)\)

// 滚动数组(一维空间优化)
for(int i = 1; i <= n; i++){    // 枚举每个物品组
    for(int j = v; j <= m; j++) // 和01背包不同的是状态有从f[i][j-v[i]]转移,所以j从小到大枚举
        f[j] = max(f[j], f[j - v[i]] + w[i]);
}

多重背包

核心: 针对不同的数据范围,掌握 \(2\) 种对多重背包的优化方法(二进制优化 & 单调队列优化)

一般题意:

​ 有 \(N\) 种物品和一个容量是 \(V\) 的背包。第 \(i\) 种物品最多有 \(s_i\) 件,每件体积是 \(v_i\) ,价值是 \(w_i\)

​ 求解将哪些物品装入背包,可使物品体积总和不超过背包容量,且价值总和最大。

多重背包 \(I\)

数据范围:

\(0<N,V≤100\)

\(0<v_i,w_i,s_i≤100\) (不绝对大概这个量级)

状态表示: $f[i][j] $ 表示前 \(i\) 个物品总体积不超过 \(j\) 的最大价值

状态计算: \(f[i][j] = max(f[i][j], f[i-1][j - k * v_i + k * w_i])\)

把这种多重背包看成特殊的完全背包,或者完全背包看成特殊的多重背包都可以。堆循环即可

(代码略)

多重背包 \(II\) (二进制优化)

时间复杂度:\(O(n\times m\times \log{s})\)

数据范围:

\(0<N≤1000\)

\(0<V≤2000\)

\(0<v_i,w_i,s_i≤2000\)

二进制优化转化为01背包证明:

\(\because 每一个s_i都可以以二进制表示表示成s_i=1+2+4+8+\dots+2^{k-1}+c\quad(c可以不是2的幂)\)

\(设对于每个s_i,c<2^k,拆成k个物品,则每个物品的体积是2^{k_i}\times w\)

\(\therefore 所有物品被拆分成了cnt个物品(cnt=\sum_{j=1}^{n}{k_j}),对每个物品的选择决策就是一个01背包问题\)

const int N = 2010;
int f[N], v[N], w[N], l[N];
int n, m;

int main(){
	ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
	cin >> n >> m;
	for(int i = 1; i <= n; i++)
		cin >> v[i] >> w[i] >> l[i];
	for(int i = 1; i <= n; i++){
		int res = l[i];
		for(int k = 1; k <= res; res -= k, k <<= 1)
			for(int j = m; j >= k * v[i]; j--)
				f[j] = max(f[j], f[j - k * v[i]] + k * w[i]);
		for(int j = m; j >= v[i] * res; j--)
			f[j] = max(f[j], f[j - res * v[i]] + res * w[i]);
	}
	cout << f[m] << endl;
    return 0;
}

多重背包 \(III\) (单调队列优化)

时间复杂度:\(O(n\times m)\)

数据范围:

\(0<N≤1000\)

\(0<V≤20000\)

\(0<v_i,w_i,s_i≤20000\)

单调队列优化证明:

详情参考此篇题解

​ 每次枚举物品组开始时,备份数组保存 \(f[i-1]\) 的信息,进而使用一维状态

​ 把 \(m\) 表示成 \(k*v + j\) \((j=m\%v)\) , \(f[j+k*v]=max(f[j+k_i*v]+(k-k_i)*w)\quad while(k_i≤k)\)

​ 因为想用单调队列维护最大值,所以队头的数不能发生变化,故转化为:

\(f[j+k*v]=max(f[j+k_i*v-k_i*w])+k*w\)

单调队列维护问题,重要的两点:

	1. 维护队列元素的个数,如果不能继续入队,弹出队头元素
	2. 维护队列的单调性,即:$尾值 >= dp[j + k*v] - k*w$
// 队列元素维护的是f[]的下标
int main(){
    int n, V;
    cin >> n >> V;
    while(n--){
        int v, w, s;
        cin >> v >> w >> s;
        memcpy(g, f, sizeof f);         // 用g[]备份dp[i-1][]的状态
        for(int j = 0; j < v; j++){
            int hh = 0, tt = -1;        // 每次枚举j的时候重置队列
            for(int k = j; k <= V; k += v){
                if(hh <= tt && q[hh] < k - v * s)       // 窗口元素过多
                    hh++;
                if(hh <= tt)        // 队列存在,更新f[k], g[q[hh]]本身包含+(q[hh] - j)/v*w;
                    f[k] = max(f[k], g[q[hh]] + (k - q[hh]) / v * w);   
                    
                while(hh <= tt && g[q[tt]] - (q[tt] - j) / v * w < g[k] - (k - j) / v * w)
                    tt--;       // 队尾元素小于等于/小于 g[k]-(k-j)/v*w,剔除队列
                q[++tt] = k;    // 记得入队
            }
        }
    }
    cout << f[V] << endl;
    return 0;
}

** 更 easy 的写法

const int N = 1e4 + 10;
int f[N], q[N][2], n, m, v, w, cnt;

int main(){
	ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
	cin >> n >> m;
	for(int i = 1; i <= n; i++){
		cin >> v >> w >> cnt;
		for(int j = 0; j < v; j++){
			int hh = 0, tt = -1;
			for(int p = j, k = 1; p <= m; k++, p += v){
				int val = f[p] - k * w, ed = k + cnt;
				while(hh <= tt && q[tt][0] <= val) tt--;
				q[++tt][0] = val;
				q[tt][1] = ed;
				f[p] = q[hh][0] + k * w;
				while(hh <= tt && q[hh][1] == k) hh++;
			}
		}
	}
	cout << f[m] << endl;
    return 0;
}

分组背包

先枚举物品组、再从大到小枚举体积、枚举组内决策

一般题意:

\(N\) 组物品和一个容量是 \(V\) 的背包。每组物品有若干个,同一组内的物品最多只能选一个。每件物品的体积是 \(v_{ij}\),价值是 \(w_{ij}\),其中 \(i\) 是组号,\(j\) 是组内编号。 求解将哪些物品装入背包,可使物品总体积不超过背包容量,且总价值最大。

状态表示: \(f[i][j]\) 表示前 \(i\) 个组,体积不超过 \(j\) 的选取所有方案最大价值

状态计算: \(f[i][j] = max(f[i - 1][j], f[i - 1][j - v_{ik} + w_{ik}])\)

const int N = 1010;
int f[N], v[N], w[N], a[N];
vector<int> c[1001];
int n, m;

int main(){
	ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
	cin >> n >> m;
	for(int i = 1; i <= n; i++){
		cin >> a[i] >> v[i] >> w[i];
		c[a[i]].pb(i);
	}
	for(int i = 1; i <= 1000; i++){
		for(int j = m; ~j; j--){
			for(auto t: c[i])
				if(v[t] <= j)
					f[j] = max(f[j], f[j - v[t]] + w[t]);
		}
	}
	cout << f[m] << endl;
    return 0;
}

二维背包

\(n\) 种物品要放到一个袋子里,袋子的总容量为 \(m\) ,我们一共有 \(k\) 点体力值。第 \(i\) 种物品的体积为 \(v_i\),把它放进袋子里会获得
\(w_i\) 的收益,并且消耗 \(t_i\) 点体力值,每种物品只能取一次。问如何选择物品,使得在物品的总体积不超过 \(m\)
并且花费总体力不超过 \(k\) 的情况下,获得最大的收益?请求出最大收益。
状态表示: \(f[i][j][x]\) 表示选了前 \(i\) 个物品,消耗体积积为 \(j\),消耗体力为 \(x\) 的最大收益
状态转移: \(f[i][j][x] = max(f[i - 1][j][x], f[i - 1][j - v[i]][x - t[i]] + w[i])\)

const int N = 1010;
int n, m, k;
int f[N][N], v[N], w[N], t[N];

int main(){
	ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
	cin >> n >> m >> k;
	for(int i = 1; i <= n; i++)
		cin >> v[i] >> w[i] >> t[i];
	for(int i = 1; i <= n; i++)
		for(int j = m; ~j; j--)
			for(int x = k; ~x; x--){
				if(x >= t[i] && j >= v[i])
				f[j][x] = max(f[j][x], f[j - v[i]][x - t[i]] + w[i]);
			}
	cout << f[m][k] << endl;
    return 0;
}

状态机模型DP

  • 建立状态机模型,多开一维表示状态,可以状态之间互相转移

状态压缩 DP

棋盘模型

要点

  • 放置合法问题,考虑从行的角度出发,当前行能不能放和上一行的放置状态有关,二进制压缩结合位运算即可破题

示例:炮兵阵地

#define pb push_back
const int N = 110, M = 1 << 10;
int n, m, cnt[M], g[N], f[2][M][M];     
vector<int> state;

// 以放置在哪一行为决策阶段,空间不够,滚动数组优化,只用到了 i - 1 行状态,直接取 &
// 状态表示:f[i][j][k],表示当前在第 i 行,第 i 行状态为 j, 第 i - 1 行状态为 k 的放置数量最大值
// 状态计算:f[i & 1][b][a] = max(f[i & 1][b][a], f[(i - 1) & 1][a][c] + cnt[b]);


bool check(int x){
    for(int i = 0; i < m; i++)
        if( x >> i & 1 && (x >> (i + 1) & 1 || x >> (i + 2) & 1))       // 1 相邻的4个位置不能有1
            return false;
    return true;
}

int count(int x){
    int res = 0;
    for(int i = 0; i < m; i++)
        if(x >> i & 1) res++;
    return res;
}

int main(){
    cin >> n >> m;
    for(int i = 1; i <= n; i++)
        for(int j = 0; j < m; j++){
            char c;
            cin >> c;
            if(c == 'H')
                g[i] += 1 << j;
        }
    for(int i = 0; i < 1 << m; i++)
        if(check(i)){
            state.pb(i);
            cnt[i] = count(i);
        }
    for(int i = 1; i <= n + 2; i++){
        for(int j = 0; j < state.size(); j++)
            for(int k = 0; k < state.size(); k++)
                for(int u = 0; u < state.size(); u++){
                    int a = state[k], b = state[j], c = state[u];   // a:第i-1行,b:第i行,c:第i-2行
                    if(a & b | b & c | a & c) continue;             // 三个状态不能进行转移
                    if(a & g[i - 1] | b & g[i]) continue;           // 不能放置在山上
                    f[i & 1][b][a] = max(f[i & 1][b][a], f[(i - 1) & 1][a][c] + cnt[b]);
                }
    }
    cout << f[(n + 2) & 1][0][0] << endl;
    return 0;
}

二进制压缩思想

示例:Hamilton路径
给定一张完全图,求从0到n-1经过所有点的最短路径

const int N = 20, M = 1 << 20;
int f[M][N], w[N][N];
int n;

int hamilton(){
    memset(f, 0x3f, sizeof f);
    f[1][0] = 0;        // f[i][j] 状态为i, 停留在j点最短路径
    for(int i = 1; i < 1 << n; i++)         // 枚举所有状态
        for(int j = 0; j < n; j++)      // 枚举状态中停留在哪个点
            if(i >> j & 1)      // 有这个点
                for(int k = 0; k < n; k++)  // 枚举从哪个点移动到 j 
                    if((i ^ 1 << j) >> k & 1)       // 有 k 这个点存在
                        f[i][j] = min(f[i][j], f[i ^ 1 << j][k] + w[k][j]);     // 状态转移
            
    return f[(1 << n) - 1][n - 1];      // 经过了所有点,最后停留在 n - 1 这个点
}

int main(){
    cin >> n;
    for(int i = 0; i < n; i++)
        for(int j = 0; j < n; j++)
            cin >> w[i][j];
    cout << hamilton() << endl;
    return 0;
}

集合模型

示例: 宝藏
(状态压缩DP) O(n2*3n)

状态压缩DP,下文中i是一个 \(n\) 位二进制数,表示每个点是否存在。

状态\(f[i][j]\)表示:

集合:所有包含i中所有点,且树的高度等于j的生成树
属性:最小花费
状态计算:枚举i的所有非全集子集S作为前j - 1层的点,剩余点作为第j层的点。
核心: 求出第j层的所有点到S的最短边,将这些边权和乘以j,直接加到f[S][j - 1]上,即可求出f[i][j]。

证明:
将这样求出的结果记为f'[i][j]

f[i][j]中花费最小的生成树一定可以被枚举到,因此f[i][j] >= f'[i][j];
如果第j层中用到的某条边(a, b)应该在比j小的层,假设a是S中的点,b是第j层的点,则在枚举S + {b}时会得到更小的花费,即这种方式枚举到的所有花费均大于等于某个合法生成树的花费,因此f[i][j] <= f'[i][j]
所以有 f[i][j] = f'[i][j]。

时间复杂度
包含 \(k\) 个元素的集合有 \(C_n^k\) 个,且每个集合有 \(2^k\) 个子集,因此总共有 \(C_n^k * 2^k\) 个子集。\(k\) 可以取 \(0∼n\),则总共有 \(\Sigma_{k=0}^nC_n^k2^k=(1+2)^n=3^n\) ,这一步由二项式定理可得。

对于每个子集需要 \(n^2\) 次计算来算出剩余点到子集中的最短边。

因此总时间复杂度是 O(n2*3n)。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

const int N = 12, M = 1 << 12, INF = 0x3f3f3f3f;

int n, m;
int d[N][N];
int f[M][N], g[M];

int main()
{
    scanf("%d%d", &n, &m);

    memset(d, 0x3f, sizeof d);
    for (int i = 0; i < n; i ++ ) d[i][i] = 0;

    while (m -- )
    {
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        a --, b --;
        d[a][b] = d[b][a] = min(d[a][b], c);
    }

    for (int i = 1; i < 1 << n; i ++ )
        for (int j = 0; j < n; j ++ )
            if (i >> j & 1)
            {
                for (int k = 0; k < n; k ++ )
                    if (d[j][k] != INF)
                        g[i] |= 1 << k;
            }

    memset(f, 0x3f, sizeof f);
   for (int i = 0; i < n; i ++ ) f[1 << i][0] = 0;

    for (int i = 1; i < 1 << n; i ++ )
        for (int j = (i - 1); j; j = (j - 1) & i)
            if ((g[j] & i) == i)
            {
                int remain = i ^ j;
                int cost = 0;
                for (int k = 0; k < n; k ++ )
                    if (remain >> k & 1)
                    {
                        int t = INF;
                        for (int u = 0; u < n; u ++ )
                            if (j >> u & 1)
                                t = min(t, d[k][u]);
                        cost += t;
                    }

                for (int k = 1; k < n; k ++ ) f[i][k] = min(f[i][k], f[j][k - 1] + cost * k);
            }

    int res = INF;
    for (int i = 0; i < n; i ++ ) res = min(res, f[(1 << n) - 1][i]);

    printf("%d\n", res);
    return 0;
}

区间DP

要点

  • 以区间长度为决策阶段

一维区间DP

一些技巧

  • 环形区间,可以复制区间在后面,破环成链

例题:环形石子合并

const int N = 410, INF = 0x3f3f3f3f;
int w[N], n, s[N];
int f[N][N], g[N][N];

// 两条相同的链模拟环
// 状态表示: f[l][r] 表示从区间 l 到 区间 r 的最小合并费用
// 状态转移: f[l][r] = min(f[l][r], f[l][k] + f[k + 1][r] + s[r] - s[l - 1]);


int main(){
    cin >> n;
    for(int i = 1; i <= n; i ++){
        cin >> w[i];
        w[i + n] = w[i];
    }
    memset(f, 0x3f, sizeof f);
    memset(g, -0x3f, sizeof g);
    for(int i = 1; i <= 2 * n; i++)
        s[i] = s[i - 1] + w[i];
    for(int len = 1; len <= n; len++)
        for(int l = 1; l + len - 1 <= 2 * n; l++){
            int r = l + len - 1;
            if(r > 2 * n) break;
            if(len == 1) f[l][r] = g[l][r] = 0;
            else
                for(int k = l; k <= r; k++){
                    f[l][r] = min(f[l][r], f[l][k] + f[k + 1][r] + s[r] - s[l - 1]);
                    g[l][r] = max(g[l][r], g[l][k] + g[k + 1][r] + s[r] - s[l - 1]);
                }
        }
    int max_ = -INF, min_ = INF;
    for(int i = 1; i <= n; i++){
        max_ = max(max_, g[i][i + n - 1]);      // 长度为n的区间,端点差为len - 1
        min_ = min(min_, f[i][i + n - 1]);
    }
    cout << min_ << endl << max_;
    return 0;
}

二维区间DP

例题:棋盘分割

#include<iomanip>
const int N = 9, M = 16;
double f[N][N][N][N][M], X;
int n, g[N][N], s[N][N];

// 二维区间DP
// 状态表示:f[x1][y1][x2][y2][k], 表示在左上顶点为(x1, y1), 右下顶点为(x2, y2) 的矩形内分割 k 个矩形的最小方差
// 状态计算:循环切割的分割线(水平与竖直和继续往哪个部分切)

int get_sum(int x1, int y1, int x2, int y2){
    return s[x2][y2] - s[x1 - 1][y2] - s[x2][y1 - 1] + s[x1 - 1][y1 - 1];
}

double get(int x1, int y1, int x2, int y2){
    double sum = get_sum(x1, y1, x2, y2) - X;
    return sum * sum / n;
}

double dp(int x1, int y1, int x2, int y2, int k){
    double& v = f[x1][y1][x2][y2][k];
    if(v >= 0) return v;
    if(k == 1) return get(x1, y1, x2, y2);
    v = 1e9;
    for(int i = x1; i < x2; i++){
        v = min(v, dp(x1, y1, i, y2, k - 1) + get(i + 1, y1, x2, y2)); 
        v = min(v, dp(i + 1, y1, x2, y2, k - 1) + get(x1, y1, i, y2));
    }
    for(int i = y1; i < y2; i++){
        v = min(v, dp(x1, y1, x2, i, k - 1) + get(x1, i + 1, x2, y2));
        v = min(v, dp(x1, i + 1, x2, y2, k - 1) + get(x1, y1, x2, i));
    }
    return v;
}

int main(){
    cin >> n;
    for(int i = 1; i <= 8; i++)
        for(int j = 1; j <= 8; j++){
            cin >> g[i][j];
            s[i][j] = s[i - 1][j] + s[i][j - 1] - s[i - 1][j - 1] + g[i][j];
        }
    memset(f, -1, sizeof f);
    X = s[8][8] * 1.0 / n;
    cout << fixed << setprecision(3) << sqrt(dp(1, 1, 8, 8, n)) << endl;
    return 0;
}

树形DP

要点

  • 以一个节点的状态来表示一个集合

树形DP求树的直径

const int N = 1e4 + 10, M = 2 * N;
int n, idx, e[M], ne[M], h[N], w[M], ans;

// 状态表示:一个点表示以该点为最高点的所有路径最大长度

void add(int a, int b, int c){
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}

int dfs(int u, int fa){     // 带入父节点编号,确保是从上到下搜索
    int d1 = 0, d2 = 0;     // 最大、次大到叶子节点的路径
    for(int i = h[u]; ~i; i = ne[i]){
        int j = e[i];
        if(j == fa) continue;
        int d = dfs(j, u) + w[i];
        if(d >= d1) d2 = d1, d1 = d;
        else if(d > d2) d2 = d;
    }
    ans = max(ans, d1 + d2);
    return d1;              // 返回最长的到叶子节点的路径
}

int main(){
    cin >> n;
    memset(h, -1, sizeof h);
    for(int i = 0; i < n - 1; i++){
        int a, b, c;
        cin >> a >> b >> c;
        add(a, b, c), add(b, a, c);
    }
    dfs(1, -1);
    cout << ans << endl;
    return 0;
}

树上背包

换根DP

  • 先考虑以 1 为根做一次 dfs,自底向下,儿子节点信息更新父亲 v -> u
  • 再以 1 为根做一次dfs,在递归前通过父节点信息更新儿子节点信息 u -> v
  • 考虑换根过程,根从 1 换到其他点特殊考虑,其他点互相换根时,设儿子为 x ,根为 y
    • 当要换根到 x 时,先减去 xy 中的贡献,然后重新计算 y 作为子树对新根 x 的贡献。

应用:

  1. 计算树中点到其他点距离和
  2. 树中每个点流大小
  3. 每个点的最长路径
  4. 最大子链和(链权值为点权和, 换根DP)

换根DP计算树中每个点的最长路径

#include<bits/stdc++.h>
typedef long long ll;
typedef unsigned long long ull;
typedef std::pair<int, int> PII;
typedef std::pair<ll, ll> PLL;
#define x first
#define y second
#define pb push_back
#define mkp make_pair
#define endl "\n"
using namespace std;
const int N = 1e5 + 10;
int h[N], e[N << 1], ne[N << 1], idx, n;
int f[N][2][2], g[N];		// f[i[0][0/1] i点最长路径长度(0)节点(1) f[i][1][0/1] 代表次长路径

void add(int a, int b){
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

void dfs1(int u, int p){
	for(int i = h[u]; ~i; i = ne[i]){
		int v = e[i];
		if(v == p) continue;
		dfs1(v, u);
		int len = f[v][0][0] + 1;
		if(len > f[u][0][0])
			f[u][1][0] = f[u][0][0], f[u][1][1] = f[u][0][1], f[u][0][0] = len, f[u][0][1] = v;
		else if(len > f[u][1][0])
			f[u][1][0] = len, f[u][1][1] = v;
	}
}

void dfs2(int u, int p){
	for(int i = h[u]; ~i; i = ne[i]){
		int v = e[i];
		if(v == p) continue;
		if(f[u][0][1] == v){
			g[v] = max(f[u][1][0], g[u]) + 1;
		}
		else{
			g[v] = max(f[u][0][0], g[u]) + 1;
		}
		dfs2(v, u);
	}
}
int main(){
	ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    memset(h, -1, sizeof h);
    cin >> n;
    for(int i = 1; i < n; i++){
    	int a, b;
    	cin >> a >> b;
    	add(a, b), add(b, a);
    }
    dfs1(1, -1);
    dfs2(1, -1);
    for(int i = 1; i <= n; i++){
    	cout << f[i][0][0] + f[i][1][0] + g[i] - min(min(f[i][0][0], f[i][1][0]), g[i]) << endl;
    }
    return 0;
}

数位DP

要点

  • 以树的形式(思考数位)的形式思考数位DP问题
  • 查询 \([l, r]\) 中有多少个数合法,可用前缀和思想 dp(r) - dp(l - 1)dp(x) 求出 \([1,x]\) 中合法的数

状态表示

  • \(f[i][j]\) 表示最高位为 \(j\) ,有 \(i\) 位的方案数
  • \(f[i][j][k]\),表示有 \(i\) 位,最高位为 \(j\),所有位数之和 \(\% p\) 的结果是 \(k\) 的数字个数
  • \(f[i][j][a][b]\), 有 \(i\) 位,最高位为\(j\),数位和 \(\%7=a\) ,数值 \(\%7=b\),合法方案个数

循环迭代

例题:度的数量

  • 求有多少个数满足有 \(k\)\(B\) 进制下的 \(1\)
int K, B, l, r;
int C[N][N];
// 以树的形式(思考数位)的形式思考数位DP问题
void init(){
    for(int i = 0; i < N; i++)
        for(int j = 0; j <= i; j++)
            if(!j) C[i][j] = 1;
            else C[i][j] = C[i - 1][j] + C[i - 1][j - 1];
}

int dp(int n){
    if(!n) return 0;        // 第一步:判断边界
    vector<int> v;
    while(n){               // 第二步:数位转换
        v.push_back(n % B);
        n /= B;
    }
    int res = 0, last = 0;  // 第三部:定义答案res和取了多少B进制下的1的个数last
    for(int i = v.size() - 1; i >= 0; i--){         // 从高到低枚举每一位
        int x = v[i];
        if(x){      // x 为 0 没有讨论的价值,要继续下一位讨论
            res += C[i][K - last];      // x > 0,这一位可以取0,后面可以取 K-last 个1
            if(x > 1){                  // x > 1, 这一位还可以取1,后面 i 位可以取 K-last-1 个 1
                if(K - last - 1 >= 0) res += C[i][K - last - 1];
                break;                  // 因为右边分支大于1,因为题目要求只能有0或1,所以右边不合法不继续讨论了
            }
            else{                       // x = 1, 要记录last++,表示必须得取1了
                last++;
                if(last > K) break;
            }
        }
        if(!i && last == K)             // 取到了最后,数位上有last个1可以取,特殊情况res++
            res++;
    }
    return res;
}


int main(){
    cin >> l >> r >> K >> B;
    init();
    cout << dp(r) - dp(l - 1) << endl;

    return 0;
}

记忆化搜索

ll dp[N][state];        // 状态根据题目性质改变, 例子记录数位中 非零 数位的个数
// 从高位向低位递归
ll dfs(int pos, int cnt, bool lead, bool limit){    // (当前数位, 根据题目需要记录状态, 是否有前导零, 前面的数位是否填满)
	if(pos == -1) return 1;	    // 递归出口, 可能需要判断是否符合题目条件
	if(!limit && !lead && dp[pos][cnt] != -1) return dp[pos][cnt];  // 记忆化, 具体看题目,一般需要 !limit 与 !lead
	int up = limit ? a[pos] : 9;    // 根据前面是否填满, 设立枚举上界
    /*
        灵活修改
        if(cnt == k) up = 0;
    */
	ll res = 0;
	for(int i = 0; i <= up; i++){
        /*
            灵活修改, 进行递归
		    int t = (i != 0);
            res += dfs(pos - 1, cnt + t, lead && i == 0, limit && i == a[pos]); // 注意lead和limit的传递
        */
	}
	if(!limit && !lead) dp[pos][cnt] = res;     // 根据需要, 进行记忆化存储
	return res;
}

单调队列优化DP

  • 单调队列只能保存已经经过的点的信息
    例题:烽火传递
using namespace std;
const int N = 2e5 + 10;
int f[N], w[N], n, m, q[N];

// 状态表示: f[i] 表示前i - 1个烽火已传递,并点燃第 i 烽火的最小代价(这样定义可以具有无后效性)
// 状态计算:f[i] = min(f[j]) + w[i] (i - m <= j < i - 1),用单调队列优化

int main(){
    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> w[i];
    int hh = 0, tt = 0;
    for(int i = 1; i <= n; i++){
        if(hh <= tt && q[hh] < i - m) hh++;
        f[i] = f[q[hh]]+ w[i];
        while(hh <= tt && f[q[tt]] >= f[i]) tt--;
        q[++tt] = i;
    }
    while(hh <= tt && q[hh] < n - m + 1) hh++;      // 小技巧,队列后移一位的答案就是 min(f[n - m + 1, n])
    cout << f[q[hh]] << endl;
    return 0;
}

二维滑动窗口

输出仅一个整数,为 \(a×b\) 矩阵中所有 " \(n×n\) 正方形区域中的最大整数和最小整数的差值"的最小值。
例题:理想的正方形

const int N = 1010;
int w[N][N], row_min[N][N], row_max[N][N], q[N], n, k, m;

// 思路:讨论最小值,最大值同理,先对每一行做单调队列记录每个点为长度为k窗口右端点,记录窗口最小值
//       对行的记录结果在列上做单调队列,对于k 行k 列之后的所有点上的数,就是 k * k 的矩阵内的最值

// 获取一个滑动窗口最小值,b[] 存取结果
void get_min(int a[], int b[], int len){
    int hh = 0, tt = 0;
    for(int i = 1; i <= len; i++){
        if(hh <= tt && q[hh] <= i - k) hh++;
        while(hh <= tt && a[q[tt]] >= a[i]) tt--;
        q[++tt] = i;
        b[i] = a[q[hh]];
    }
}

// 获取一个滑动窗口最大值,b[] 存取结果
void get_max(int a[], int b[], int len){
    int hh = 0, tt = 0;
    for(int i = 1; i <= len; i++){
        if(hh <= tt && q[hh] <= i - k) hh++;
        while(hh <= tt && a[q[tt]] <= a[i]) tt--;
        q[++tt] = i;
        b[i] = a[q[hh]];
    }
}

int main(){
    ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
    cin >> n >> m >> k;
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= m; j++)
            cin >> w[i][j];
    for(int i = 1; i <= n; i++){
        get_min(w[i], row_min[i], m);
        get_max(w[i], row_max[i], m);
    }
    int a[N], b[N], c[N], res = 1e9 + 10;
    for(int i = k; i <= m; i++){
        for(int j = 1; j <= n; j++) a[j] = row_min[j][i];
            get_min(a, b, n);
        for(int j = 1; j <= n; j++) a[j] = row_max[j][i];
            get_max(a, c, n);
        for(int j = k; j <= n; j++)
            res = min(res, c[j] - b[j]);
    }
    cout << res << endl;
    return 0;
}

斜率优化DP

  • 斜率优化普遍需要二分/CDQ/平衡树优化

基础凸包优化

例题:任务安排

\(n\) 个 任务排成一个序列,顺序不得改变,其中第 \(i\) 个 任务的耗时为 \(T_i\), 费用系数为 \(C_i\)

现需要把该 \(n\) 个 任务分成若干批进行加工处理

每批次的段头,需要额外消耗 \(S\) 的时间启动机器。每一个任务的完成时间是所在批次的结束时间。

完成一个任务的费用为:从 \(0\) 时刻 到该任务 所在批次结束的时刻 \(t\) 乘以该任务费用系数 \(C\)

  • 运用费用提前思想
const int N = 3e5 + 10;
ll f[N], q[N], t[N], c[N], n, s;        // f[i] 表示执行到任务i的最小费用

// 1. 斜率 = t[i] + s 单调递增
// 2. 横坐标 c[i] 单调递增
// 3. f[i] 在截距上,找到截距最大的点,考虑凸包优化,找到第一个斜率大于 k 的点就是要找的答案,由于单调,采用单调队列优化到O(n)

int main(){
    cin >> n >> s;
    for(int i = 1; i <= n; i++){
        cin >> t[i] >> c[i];
        t[i] += t[i - 1], c[i] += c[i - 1];
    }
    int hh = 0, tt = 0;
    q[0] = 0;
    for(int i = 1; i <= n; i++){            // 队列hh、tt 带上 q[]
        while(hh < tt && (f[q[hh + 1]] - f[q[hh]]) < (t[i] + s) * (c[q[hh + 1]] - c[q[hh]])) hh++;
        int j = q[hh];
        f[i] = f[j] + t[i] * (c[i] - c[j]) + s * (c[n] - c[j]);
        while(hh < tt && (f[i] - f[q[tt - 1]]) * (c[q[tt]] - c[q[tt - 1]]) < (f[q[tt]] - f[q[tt - 1]]) * (c[i] - c[q[tt - 1]])) tt--;
        q[++tt] = i;
    }
    cout << f[n] << endl;
    return 0;
}

基环树上DP

基环树: \(n\) 个点,\(n\) 条边的树。

例题:P2607 [ZJOI2008] 骑士

没有上司的舞会+基环树。

const int N = 1e6 + 10;
const ll INF = 2e18;
ll n, a[N], fa[N], dp[N][2], dp2[N][2];
bool vis[N], oncyc[N];
vector<int> edge[N];

void dfs(int u) {
    vis[u] = true;
    dp[u][1] = a[u];
    for (auto v: edge[u]) {
        if (oncyc[v]) continue;     // 在环上跳过
        dfs(v);
        dp[u][0] += max(dp[v][0], dp[v][1]);
        dp[u][1] += dp[v][0];
    }
}

int main() {
    re(n);
    for (int i = 1, p; i <= n; i++) {
        re(a[i]), re(p);
        edge[p].pb(i);
        fa[i] = p;
    }
    ll ans = 0;
    // 对于每个连通分量求解
    for (int i = 1; i <= n; i++) {
        if (vis[i]) continue;
        // 找出每个环
        int now = i;
        while (!vis[now]) {
            vis[now] = true;
            now = fa[now];
        }
        vector<int> cyc;
        while (!oncyc[now]) {
            oncyc[now] = true;
            cyc.pb(now);
            now = fa[now];
        }
        // 对非环节点树形 dp
        for (auto u: cyc)
            dfs(u);
        // 环上dp
        int m = SZ(cyc);
        ll res = -INF;
        for (int t = 0; t < 2; t++) {
            for (int j = 0; j < 2; j++) {
                if (t == j) dp2[0][j] = dp[cyc[0]][t];
                else dp2[0][j] = -INF;
            }
            for (int i = 1; i < m; i++) {
                dp2[i][0] = max(dp2[i - 1][0], dp2[i - 1][1]) + dp[cyc[i]][0];
                dp2[i][1] = dp2[i - 1][0] + dp[cyc[i]][1];
            }
            if (t == 0) res = max(dp2[m - 1][0], dp2[m - 1][1]);
            else res = max(res, dp2[m - 1][0]);
        }
        ans += res;
    }
    printf("%lld\n", ans);
    return 0;
}

IO 等杂项

查找二进制1个数

__builtin_popcount(); // 32位无符号整形
__builtin_popcountll();

高性能随机数

unsigned seed1 = std::chrono::system_clock::now().time_since_epoch().count(); // #include<chrono>
mt19937_64 rd(seed1); // u64
mt19937 rd(seed1);

cmath库

log2(x);    // 返回 log2(x)
hypot(x, y);    // 返回 x, y 两点间距离

快读

inline long long _read() {
	static long long ans;
	static unsigned int c;
	static bool p;
	for (c = getchar(); c != '-' && (c < '0' || c > '9'); c = getchar());
	if (c == '-') p = false, c = getchar(); else p = true;
	for (ans = 0; c <= '9' && c >= '0'; c = getchar()) ans = ans * 10 + c - '0';
	return p ? ans : -ans;
}

inline void _write(long long ans) {
	static int a[20], n;
	if (ans == 0) {
		putchar('0');
		return;
	}
	if (ans < 0) {
		putchar('-');
		ans = -ans;
	}
	for (n = 0; ans; ans /= 10) a[n++] = ans % 10;
	for (n--; n >= 0; n--) putchar(a[n] + '0');
	return;
}

cin、cout解绑

ios::sync_with_stdio(false);
cin.tie(nullptr), cout.tie(nullptr);

cout精度控制

cout << fixed << setprecision(5) << 1.2 << endl;//输出结果为1.20000

C++优化

\(O2\)优化

#pragma GCC optimize(2)

\(O3\)优化

#pragma GCC optimize(3,"Ofast","inline")

测时间

#ifdef DEBUG
clock_t start, end;
start = clock();
#endif

#ifdef DEBUG
end = clock();
printf("End: %ld\n",end);
double elapsedTime = static_cast<double>(end-start) / CLOCKS_PER_SEC ;
printf("CPU PROCESSING TIME: %f",elapsedTime);
#endif

Python3 多行输入

while True:
    try:
        m = int(input().strip())
    except EOFError:
        break

加栈

// Windows 评测
#pragma comment(linker,"/STACK:1024000000,1024000000")
// Windows 本地加入编译命令
-Wl,--stack=64000000000


玄学优化

#pragma GCC diagnostic error "-std=c++11"
#pragma GCC optimize("-fdelete-null-pointer-checks,inline-functions-called-once,-funsafe-loop-optimizations,-fexpensive-optimizations,-foptimize-sibling-calls,-ftree-switch-conversion,-finline-small-functions,inline-small-functions,-frerun-cse-after-loop,-fhoist-adjacent-loads,-findirect-inlining,-freorder-functions,no-stack-protector,-fpartial-inlining,-fsched-interblock,-fcse-follow-jumps,-fcse-skip-blocks,-falign-functions,-fstrict-overflow,-fstrict-aliasing,-fschedule-insns2,-ftree-tail-merge,inline-functions,-fschedule-insns,-freorder-blocks,-fwhole-program,-funroll-loops,-fthread-jumps,-fcrossjumping,-fcaller-saves,-fdevirtualize,-falign-labels,-falign-loops,-falign-jumps,unroll-loops,-fsched-spec,-ffastmath,Ofast,inline,-fgcse,-fgcse-lm,-fipa-sra,-ftree-pre,-ftree-vrp,-fpeephole2",3)
#pragma GCC target("avx","sse2")-
posted @ 2022-03-06 21:39  Roshin  阅读(384)  评论(0编辑  收藏  举报
-->