CSP-S 考前备战——常考知识点串烧

1.树形结构 与 树形dp

PS :在CSP-S 2019,CSP-J 2020,CSP-S 2020,CSP-S 2021 均有考查
此类问题的做题方法就是将问题转化成树上的问题,然后进行深度优先遍历就可以了,之后在深度优先遍历上稍作修改即可。

首先要知道如何写深度优先遍历:

void dfs(int x, int fa)
{
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == fa)
        {
            continue;
        }
        dfs(y, x);
    }
}

之后就可以开始解决一些简单基础的问题了。

树形 dp 求树的直径

模板:边权树直径

给定一棵无根树

第一行为一个正整数 \(n\) , 表示这颗树有 \(n\) 个节点

接下来的 \(n−1\) 行,每行三个正整数 \(u,v,w\),表示 $u,v $ \((u,v≤n)\) 有一条权值为 \(w\) 的边相连

求树的最长链的距离。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

using std::cin;
using std::cout;
using std::max;

const int N = 1e6 + 5;
const int M = 4e6 + 5;

int idx = 0;
int d[N];
int h[N], e[M], ne[M], w[M];
int ans;

void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x, int fa)
{
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == fa)
        {
            continue;
        }
        dfs(y, x);
        ans = max(ans, d[x] + d[y] + w[i]);
        d[x] = max(d[x], d[y] + w[i]);
    }
}

int main()
{
    int n;

    cin >> n;

    for (rint i = 1; i < n; i++)
    {
        int a, b, c;
        cin >> a >> b >> c;
        add(a, b, c);
        add(b, a, c);
    }

    dfs(1, 1);

    cout << ans << endl;

    return 0;
}

然后就是结点带权求树的直径

模板:点权树直径

给出一棵结点带权的树,定义树上的最长子链为子链上权值和最大的子链,求出最长子链的权值和。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

using namespace std;

const int N = 2e5 + 5;
const int M = 4e5 + 5;

int idx, h[N], e[M], ne[M];
int d[N], w[N];
int ans;

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x, int fa)
{
    d[x] = w[x];
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == fa)
        {
            continue;
        }
        dfs(y, x);
        ans = max(ans, d[x] + d[y]);
        d[x] = max(d[x], d[y] + w[x]);
    }
    ans = max(ans, d[x]);
}

int main()
{
    int n;
    cin >> n;

    for (rint i = 1; i <= n; i++)
    {
        cin >> w[i];
    }

    for (rint i = 1; i < n; i++)
    {
        int a, b;
        cin >> a >> b;
        add(a, b);
        add(b, a);
    }

    dfs(1, -1);

    cout << ans << endl;

    return 0;
}

接下来做几道例题:

[NOI2011] 道路修建

题目的意思为有一棵树,每条边都有边权,每条边有对应的贡献,贡献为边两边联通块大小差的绝对值乘上边权,求所有边的贡献之和。

将每条每条边的贡献表示出来:\(Δ×w_i\)

设一个连通块大小为 d[i] ,那么另一个就是 n - d[i] ,所以 \(Δ = |d[i] - (n - d[i])|\)

所以每条边的贡献就是 \(|2 × d[i] - n| × w_i\)

剩下的就是板子了

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define int long long
#define endl '\n'

using std::cin;
using std::cout;
using std::max;

const int N = 2e6 + 5;
const int M = 4e6 + 5;
const int inf = 1e18;

int idx = 0;
int h[N], ne[M], e[M], w[M], d[N];
int ans, n;

void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}

void dfs(int x, int fa)
{
    d[x] = 1;
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == fa)
        {
            continue;
        }
        dfs(y, x);
        d[x] += d[y];
        ans += abs(2 * d[y] - n) * w[i];
    }
}

signed main()
{
    cin >> n;

    for (rint i = 1; i < n; i++)
    {
        int a, b, c;
        cin >> a >> b >> c;
        add(a, b, c);
        add(b, a, c);
    }

    dfs(1, 1);

    cout << ans << endl;

    return 0;
}

[洛谷月赛] 生活在树上

也是一个很板子的题目

和上一个题目一样,开一个数组 d,然后按照题目要求异或就可以了。

这个题卡 cin,cout ,差评!!!

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define int unsigned long long
#define endl '\n'

using std::cin;
using std::cout;
const int N = 1e6 + 5;
const int M = 2e6 + 5;
const int inf = 1e18;

int idx = 0;
int h[N], ne[M], e[M], w[M], d[N];
int ans, n, m;

void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}

void dfs(int x, int fa)
{
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == fa)
        {
            continue;
        }
        d[y] = d[x] ^ w[i];
        dfs(y, x);
    }
}

signed main()
{
    scanf("%llu%llu", &n, &m);

    for (rint i = 1; i < n; i++)
    {
        int a, b, c;
        scanf("%llu%llu%llu", &a, &b, &c);
        add(a, b, c);
        add(b, a, c);
    }

    dfs(1, 1);

    for (rint i = 1; i <= m; i++)
    {
        int a, b, c;
        scanf("%llu%llu%llu", &a, &b, &c);
        if ((d[a] ^ d[b]) == c)
        {
            puts("YES");
        }
        else
        {
            puts("NO");
        }
    }

    return 0;
}

[洛谷月赛] 如何得到 npy

一共跑三次 dfs,前两次是分别从 st 出发做一次遍历,记录到达每个点的距离。第三次 dfs 计算最小值。(这个从 s 或者 t 谁开始搜都一样)

输出方案的话,带一个 to[] 数组记录就可以了。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

using std::cin;
using std::cout;

const int N = 6e5 + 5;
const int M = 2e6 + 5;

int idx = 0;
int h[N], e[M], ne[M], w[M];
int d1[N], d2[N];
int to[M];
long long d[N], ans;
bool v1[N];
bool v2[N];
int n, s, t;
int u[N], v[N];

void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx;
}

void dfs1(int x, int fa)
{
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == fa)
        {
            continue;
        }
        if (d[y] > d[x] + w[i])
        {
            d[y] = d[x] + w[i];
            to[y] = x;
        }
        dfs1(y, x);
    }
}

void dfs2(int x, int fa)
{
    ans += d[x];
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == fa)
        {
            continue;
        }
        dfs2(y, x);
    }
}

int main()
{
    cin >> n >> s >> t;

    for (rint i = 1; i <= n; i++)
    {
        d[i] = 1e18;
    }

    d[s] = d[t] = 0;

    for (rint i = 1; i < n; i++)
    {
        int a, b, c;
        cin >> a >> b >> c;
        u[i] = a;
        v[i] = b;
        add(a, b, c);
        add(b, a, c);
    }

    dfs1(s, s);
    dfs1(t, t);
    dfs2(s, s);

    cout << ans << endl;

    for (rint i = 1; i < n; i++)
    {
        int x = u[i];
        int y = v[i];
        if (x == to[y])
        {
            cout << 2;
        }
        else if (y == to[x])
        {
            cout << 1;
        }
        else
        {
            cout << 0;
        }
    }

    return 0;
}

[HAOI2009] 毛毛虫

结点带权树的直径板子题,唯一的区别就是这个题没有直接把点权给出来,但是大体还是一样的。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

using std::cin;
using std::cout;
using std::max;

const int N = 1e6 + 5;
const int M = 4e6 + 5;

int idx = 0;
int d[N];
int h[N], e[M], ne[M], w[N];
int ans;
int n,m;

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x, int fa)
{
    d[x] = w[x];
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == fa)
        {
            continue;
        }
        dfs(y, x);
        ans = max(ans, d[x] + d[y] - 1);
        d[x] = max(d[x], d[y] + w[x] - 1);
    }
}

int main()
{
    cin >> n >> m;

    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        cin >> a >> b;
        w[a]++;
        w[b]++;
        add(a, b);
        add(b, a);
    }

    dfs(1, -1);

    cout << ans + 1 << endl;

    return 0;
}

[HNOI2014] 米特运输

由于原题面实在太复杂了,所以把题面简化如下:

给一棵树,每个点有一个权值,要求修改一些点的权值,使得同一个父亲的儿子权值必须相同,且父亲的取值必须是所有儿子权值之和。求最小操作次数。

这个题暴力非常好想,建好树后沿着路径累乘即可,

状态转移:

\[d[y] = d[x] × (in[x] - 1) \]

但是题目并没有要求取模!也就是说在最终结果计算完成之前,就已经爆 long long 了。

这个时候我们可以用 log 进行优化。

\[log(a*b) = log(a) + log(b) \]

加法就不会炸了,但是要注意数组开 double (有个傻子开的 int 挂了 90pts,),注意 eps 精度

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <iomanip>
#include <math.h>

#define rint register int
#define endl '\n'

using std::cin;
using std::cout;

const int N = 1e6 + 5;
const int M = 4e6 + 5;
const double eps = 1e-9;

int idx, h[N], e[M], ne[M];
int w[N], in[N];
double d[N];
int n;
int ans = 1;
int cnt = 0;

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x, int fa, double sum)
{
    d[x] = sum + log(w[x]);
    in[x]--;
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == fa)
        {
            continue;
        }
        dfs(y, x, sum + log(in[x]));
    }
}

int main()
{
    cin >> n;
    for (rint i = 1; i <= n; i++)
    {
        cin >> w[i];
    }
    for (rint i = 1; i < n; i++)
    {
        int a, b;
        cin >> a >> b;
        in[a]++;
        in[b]++;
        add(a, b);
        add(b, a);
    }
    in[1]++;    //1号节点没有父亲,单独加一个度数
    dfs(1, 0, 0);
    std::sort(d + 1, d + n + 1);

    for (rint i = 2; i <= n; i++)
    {
        if (d[i] - d[i - 1] < eps)
        {
            cnt++;
        }
        else
        {
            ans = std::max(ans, cnt);
            cnt = 1;
        }
    }

    cout << n - ans << endl;

    return 0;
}

[POI2014] HOT-Hotels

dfs 找到深度,f1[i] 表示有 1 个点距离为 i 的个数,f2[i] 表示有 2 个点距离为 i 的个数.

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define int long long
#define rint register int
#define endl '\n'

using std::cin;
using std::cout;

const int N = 1e5 + 5;
const int M = 2e5 + 5;

int idx, h[N], e[M], ne[M];
int max_depth;
int g[N], f1[N], f2[N];
int n;
int ans;

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x, int fa, int depth)
{
    g[depth]++;
    max_depth = std::max(max_depth, depth);
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == fa)
        {
            continue;
        }
        dfs(y, x, depth + 1);
    }
}

signed main()
{
    cin >> n;

    for (rint i = 1; i < n; i++)
    {
        int a, b;
        cin >> a >> b;
        add(a, b);
        add(b, a);
    }

    for (rint i = 1; i <= n; i++)
    {
        memset(f1, 0, sizeof f1);
        memset(f2, 0, sizeof f2);
        for (rint j = h[i]; j; j = ne[j])
        {
            int y = e[j];
            max_depth = 0;
            dfs(y, i, 1);
            for (rint k = 1; k <= max_depth; k++)
            {
                ans += f2[k] * g[k];
                f2[k] += f1[k] * g[k];
                f1[k] += g[k];
                g[k] = 0;
            }
        }
    }

    cout << ans << endl;

    return 0;
}

[八省联考 2018] 林克卡特树

题目大概意思为有一棵树,割掉恰好 \(k\) 条边然后重新接上恰好 \(k\)\(0\) 权边,最大化新树的直径

\(f[i][j][0/1/2]\) 表示在 \(i\) 节点的子树内,已经有 \(j\) 条完整的链,当前 \(i\) 节点的度数为 \(0/1/2\) 的最大价值

在每个节点的全部转移结束时,进行一次更新: \(f[i][j][0]=max{f[i][j][0],f[i][j−1][1],f[i][j][2]}\)

之后对度数三种情况进行讨论:

\(f[i][j][0] = max(f[i][j][0], f[i][k][0] + f[y][j - k][0]);\)

\(f[i][j][1] = max(f[i][j][1], max(f[i][k][1] + f[y][j - k][0], f[i][k][0] + f[y][j - k][1] + w[i]));\)

\(f[i][j][2] = max(f[i][j][2], max(f[i][k][2] + f[y][j - k][0], f[i][k][1] + f[y][j - k - 1][1] + w[i]));\)

然后处理从子节点出发的链

\(f[i][0][1]=max(f[i][0][1],f[y][0][1]+w[j]);\)

\(f[i][j][0]=max(f[i][j][0],max(f[i][j-1][1],f[i][j][2]));\)

根据这些可以打出 45pts 的暴力

考虑如何优化

\(f[1][x][0]\)\(F(x)\)\(F(x)\) 是一个上凸函数。并且,恒有 \(F(0)=F(n)=0\) 。考虑用一条斜率一定的直线去切这个凸壳。设这条直线为 \(l:y=ax+b\) ,根据上凸函数的特性,斜率一定时,切线的截距一定大于割线。

令切点为 \((t,F(t))\) ,则切线满足 \(a∗t+b=F(t)\) ,即 \(b=F(t)−a∗t\) 。可以发现, \(b\) 的表达式就相当于给每个物品赋上额外的权值 \(−t\) 时,dp 所得的最优解。一次朴素的dp就可以求出截距。

由于函数上凸,发现当斜率不断增大时,对应的切点横坐标也在不断左移。可以二分求相应的斜率。每次 check 的时候顺便记录最优解选取了多少个,从而判断应该如何调整二分区间。

凸壳上多个点共线的情况时,可以限定在 dp 过程中,权值相同时优先取选择次数更小的转移。这样求出来的次数就是当前切点的左边界,二分时判断一下就可以了

#include <iostream>
#include <cstdio>
#include <algorithm>

#define int long long
#define rint register int
#define endl '\n'

using std::cin;
using std::cout;

const int N = 3e5 + 5;
const int M = 6e5 + 5;
const int inf = 1e18;

int h[N], e[M], ne[M], w[M], idx;
int n, k;
int ans;

struct node
{
    int id, v; // id is count,to limit times;v is val
};

bool operator<(node a, node b)
{
    if (a.v == b.v)
    {
        return a.id < b.id;
    }
    return a.v < b.v;
}

node operator+(node a, node b)
{
    return {a.id + b.id, a.v + b.v};
}

node max(node a, node b)
{
    return b < a ? a : b;
}

node f[N][3];

void clear()
{
    for (rint i = 1; i <= n; i++)
    {
        f[i][0].id = f[i][1].id = f[i][2].id = 0;
        f[i][0].v = f[i][1].v = f[i][2].v = -inf;
    }
}

void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x, int fa, int v)
{
    f[x][0] = {0, 0};
    f[x][1] = {0, 0};
    f[x][2] = {1, v};
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (y == fa)
        {
            continue;
        }
        dfs(y, x, v);
        f[x][2] = max(f[x][2] + f[y][0], f[x][1] + f[y][1] + node{1, w[i] + v});
        f[x][1] = max(f[x][1] + f[y][0], f[x][0] + f[y][1] + node{0, w[i]});
        f[x][0] = max(f[x][0] + f[y][0], f[x][0]);
    }
    f[x][0] = max(f[x][0], max(f[x][1] + node{1, v}, f[x][2]));
}

signed main()
{
    cin >> n >> k;
    k++;

    for (rint i = 1; i < n; i++)
    {
        int a, b, c;
        cin >> a >> b >> c;
        add(a, b, c);
        add(b, a, c);
    }

    int l = -1e12, r = 1e12;

    while (l <= r)
    {
        int mid = (l + r) >> 1;
        clear();
        dfs(1, 0, mid);
        if (f[1][0].id < k)
        {
            l = mid + 1;
        }
        else
        {
            ans = f[1][0].v - mid * k;
            r = mid - 1;
        }
    }

    cout << ans << endl;

    return 0;
}

[NOI2008] 假面舞会

这个题需要注意几个点:

  • 1.建双向边,正边边权为 1,反向边为-1
  • 2.图不一定联通
  • 3.判环判环判环
#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

using std::cin;
using std::cout;

const int N = 1e5 + 5;
const int M = 1e6 + 5;
const int inf = 0x3f3f3f3f;

int idx, h[N], e[M], ne[M], w[M];
int d[N];
int maxx, minn;
int ans, res;
bool vis[N];
int n, m;

void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x, int dist)
{
    if (d[x])
    {
        ans = std::__gcd(ans, abs(d[x] - dist));
        return;
    }
    d[x] = dist;
    vis[x] = 1;
    maxx = std::max(maxx, dist);
    minn = std::min(minn, dist);
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        dfs(y, d[x] + w[i]);
    }
}

signed main()
{
    cin >> n >> m;

    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        cin >> a >> b;
        add(a, b, 1);
        add(b, a, -1);
    }

    for (rint i = 1; i <= n; i++)
    {
        if (vis[i])
        {
            continue;
        }
        maxx = -inf;
        minn = inf;
        dfs(i, 1);
        res += maxx - minn + 1;
    }

    if (ans) // 找到环了
    {
		if (ans < 3)
        {
            cout << -1 << " " << -1 << endl;
            return 0;
        }
        for (rint i = 3; i <= ans; i++)
        {
            if (ans % i == 0)
            {
                cout << ans << " " << i << endl;
                return 0;
            }
        }
    }
    else //  没有环
    {
        if (res < 3)
        {
            cout << -1 << " " << -1 << endl;
            return 0;
        }
        cout << res << " " << 3 << endl;
    }

    return 0;
}

2.最短路径问题

PS:在 CSP-J 2019,CSP-S 2021 均有考察

Part 1.简单最短路问题

最短路径,顾名思义,在一张图上寻找最短路径,常用的方法就是 dijkstraSPFA , Floyd

下面我们做几个例题练习:

[NOI 导刊] 最长路

显而易见,然我们找最长路。

和最短路有区别吗?没有!唯一的区别就是在比较边长更新的时候,改为如果新边大于当前边,更新;

唯一傻逼的事情是这个题有负的,得用 SPFA 。。。。。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>

#define rint register int
#define endl '\n'

using namespace std;

const int N = 1e5 + 5;
const int M = 1e6 + 5;
const int inf = 1e18;

int h[N], e[M], ne[M], dist[N], w[M];
int n, m, s, idx;

queue<int> q;
bool v[N];

void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx;
}

void SPFA()
{
    memset(dist, -1, sizeof dist);
    memset(v, 0, sizeof v);
    dist[s] = 0;
    v[s] = 1;
    q.push(s);
    while (!q.empty())
    {
        int x = q.front();
        q.pop();
        v[x] = false;
        for (rint i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            int z = w[i];
            if (dist[y] < dist[x] + z)
            {
                dist[y] = dist[x] + z;
                if (!v[y])
                {
                    q.push(y);
                    v[y] = true;
                }
            }
        }
    }
}

signed main()
{
    cin >> n >> m;

    s = 1;

    for (rint i = 1; i <= m; i++)
    {
        int a, b, c;
        cin >> a >> b >> c;
        add(a, b, c);
    }

    SPFA();

    cout << dist[n] << endl;

    return 0;
}

[USACO] Piggy Back

题目大概意思是一个无向图,Bessie 从 1 号仓库走到 n 号(每次花费 x), Elsie 从 2 号仓库走到 n 号(每次花费 y),如果两个人走同一条路花费 z,求总花费最小。

跑三遍最短路,别得到 Bessie 从 1 号仓库出发的最短路,Elsie 从 2 号仓库出发的最短路,和从 n 出发到其他每个点的最短路。

在走到相遇点之前,两人一定走的是对应最短路才能减小花费

不难得出 \(ans = min(ans, x * d1[i] + y * d2[i] + z * d3[i])\)

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cstring>

#define rint register int
#define endl '\n'

using std::cin;
using std::cout;

const int inf = 0x3f3f3f3f;
const int N = 1e6 + 5;
const int M = 4e6 + 5;

int n, m, s;
int idx, h[N], e[M], ne[M], d1[N], d2[N], d3[N];
bool v[N];
std::priority_queue<std::pair<int, int>> q;

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void dijkstra(int s, int *dist)
{
    for (rint i = 1; i <= n; i++)
    {
        dist[i] = inf;
    }
    memset(v, 0, sizeof v);
    dist[s] = 0;
    q.push(std::make_pair(0, s));

    while (!q.empty())
    {
        int x = q.top().second;
        q.pop();
        if (v[x])
        {
            continue;
        }
        v[x] = 1;
        for (rint i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            int z = 1;
            if (dist[y] > dist[x] + z)
            {
                dist[y] = dist[x] + z;
                q.push(std::make_pair(-dist[y], y));
            }
        }
    }
}

int x, y, z;

int main()
{
    cin >> x >> y >> z >> n >> m;

    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        cin >> a >> b;
        add(a, b);
        add(b, a);
    }

    dijkstra(1, d1);
    dijkstra(2, d2);
    dijkstra(n, d3);

    int ans = inf;

    for (rint i = 1; i <= n; i++)
    {
        ans = std::min(ans, x * d1[i] + y * d2[i] + z * d3[i]);
    }

    cout << ans << endl;

    return 0;
}

[USACO] Dueling GPS

先跑两次最短路,倒着建图,记录两个 GPS 的最短路。之后再根据之前找到的两个最短路径,建一张新图,边权初始值都是 2,如果这个边在一个最短路上,就给它减一

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cstring>

#define rint register int
#define endl '\n'

using std::cin;
using std::cout;

const int N = 1e6 + 5;
const int M = 4e6 + 5;

int n, m, s;
int f1[N], f2[N];
int idx1, h1[N], e1[M], ne1[M], w1[M], dist1[N];
int idx2, h2[N], e2[M], ne2[M], w2[M], dist2[N];
int idx3, h3[N], e3[M], ne3[M], w3[M], dist3[N];

void add1(int b, int a, int c)
{
    e1[++idx1] = b, w1[idx1] = c, ne1[idx1] = h1[a], h1[a] = idx1;
}

void add2(int b, int a, int c)
{
    e2[++idx2] = b, w2[idx2] = c, ne2[idx2] = h2[a], h2[a] = idx2;
}

void add3(int a, int b, int c)
{
    e3[++idx3] = b, w3[idx3] = c, ne3[idx3] = h3[a], h3[a] = idx3;
}

void dijkstra1(int s)
{
    bool v[114514];
    std::priority_queue<std::pair<int, int> > q;

    memset(dist1, 0x3f, sizeof dist1);
    memset(v, 0, sizeof v);
    dist1[s] = 0;
    q.push(std::make_pair(0, s));

    while (!q.empty())
    {
        int x = q.top().second;
        q.pop();
        if (v[x])
        {
            continue;
        }
        v[x] = 1;
        for (rint i = h1[x]; i; i = ne1[i])
        {
            int y = e1[i];
            int z = w1[i];
            if (dist1[y] > dist1[x] + z)
            {
                dist1[y] = dist1[x] + z;
                f1[y] = i;
                q.push(std::make_pair(-dist1[y], y));
            }
        }
    }
}

void dijkstra2(int s)
{
    bool v[114514];
    std::priority_queue<std::pair<int, int> > q;

    memset(dist2, 0x3f, sizeof dist2);
    memset(v, 0, sizeof v);
    dist2[s] = 0;
    q.push(std::make_pair(0, s));

    while (!q.empty())
    {
        int x = q.top().second;
        q.pop();
        if (v[x])
        {
            continue;
        }
        v[x] = 1;
        for (rint i = h2[x]; i; i = ne2[i])
        {
            int y = e2[i];
            int z = w2[i];
            if (dist2[y] > dist2[x] + z)
            {
                dist2[y] = dist2[x] + z;
                f2[y] = i;
                q.push(std::make_pair(-dist2[y], y));
            }
        }
    }
}

void dijkstra3(int s)
{
    bool v[114514];
    std::priority_queue<std::pair<int, int>> q;

    memset(dist3, 0x3f, sizeof dist3);
    memset(v, 0, sizeof v);
    dist3[s] = 0;
    q.push(std::make_pair(0, s));

    while (!q.empty())
    {
        int x = q.top().second;
        q.pop();
        if (v[x])
        {
            continue;
        }
        v[x] = 1;
        for (rint i = h3[x]; i; i = ne3[i])
        {
            int y = e3[i];
            int z = w3[i];

            if (f1[x] == i)
            {
                z--;
            }
            if (f2[x] == i)
            {
                z--;
            }

            if (dist3[y] > dist3[x] + z)
            {
                dist3[y] = dist3[x] + z;
                q.push(std::make_pair(-dist3[y], y));
            }
        }
        v[x] = 0;
    }
}

int main()
{
    cin >> n >> m;

    for (rint i = 1; i <= m; i++)
    {
        int a, b, c, d;
        cin >> a >> b >> c >> d;
        add1(a, b, c);
        add2(a, b, d);
        add3(a, b, 2);
    }

    dijkstra1(n);
    dijkstra2(n);
    dijkstra3(1);

    cout << dist3[n] << endl;

    return 0;
}

[USACO] Cow Path

Floyd 算法的简单应用,点权和边权同时处理就行。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

using std::cin;
using std::cout;

const int N = 2e2 + 5e1 + 5;
const int inf = 0x3f3f3f3f;

int dist[N][N];
int w[N][N];
int n, m, T;
bool flag = true;

int main()
{
    cin >> n >> m >> T;

    memset(dist, 0x3f, sizeof dist);

    for (rint i = 1; i <= n; i++)
    {
        dist[i][i] = 0;
    }

    for (rint i = 1; i <= n; i++)
    {
        cin >> w[i][i];
    }

    for (rint i = 1; i <= m; i++)
    {
        int a, b, c;
        cin >> a >> b >> c;
        dist[a][b] = dist[b][a] = std::min(dist[a][b], c);
        w[a][b] = w[b][a] = std::max(w[a][a], w[b][b]);
    }

    while (flag == true)
    {
        flag = false;
        for (rint k = 1; k <= n; k++)
        {
            for (rint i = 1; i <= n; i++)
            {
                for (rint j = 1; j <= n; j++)
                {
                    if (dist[i][j] + w[i][j] > dist[i][k] + dist[k][j] + std::max(w[i][k], w[k][j]))
                    {
                        dist[j][i] = dist[i][j] = dist[i][k] + dist[k][j];
                        w[i][j] = w[j][i] = std::max(w[i][k], w[k][j]);
                        flag = true;
                    }
                }
            }
        }
    }

    while (T--)
    {
        int a, b;
        cin >> a >> b;
        long long ans = dist[a][b] + w[a][b];
        cout << ans << endl;
    }

    return 0;
}

Part2.分层图

分层图,顾名思义,就是好几层的最短路。

建图方式其实是有一个模板的:

for (rint i = 1; i <= m; i++)
{
    int x, y, z;
    cin >> x >> y >> z;
    add(x, y, z);
    add(y, x, z);

    for (rint j = 1; j <= k; j++)
    {
        add(x + j * n, y + j * n, ...);
        add(y + j * n, x + j * n, ...);
        add(x + (j - 1) * n, y + j * n, ...);
        add(y + (j - 1) * n, x + j * n, ...);
    }
}

中间就是建造 k 层图,由于第一层图一共只有 n 个点,所以对于当前点加个 n 就是下一层图了。

[BJWC2012] 冻结

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cstring>

#define rint register int
#define endl '\n'

using std::cin;
using std::cout;

const int inf = 0x3f3f3f3f;
const int N = 1e5 + 5;
const int M = 1e6 + 5;

int s, n, m, idx, h[N], e[M], w[M], ne[M], dist[N];
int k;
bool v[N];

std::priority_queue<std::pair<int, int>> q;

void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx;
}

void dijkstra()
{
    memset(dist, 0x3f, sizeof dist);
    memset(v, 0, sizeof v);

    dist[s] = 0;

    q.push(std::make_pair(0, s));

    while (!q.empty())
    {
        int x = q.top().second;
        q.pop();
        if (v[x])
        {
            continue;
        }
        v[x] = true;
        for (rint i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            int z = w[i];
            if (dist[y] > dist[x] + z)
            {
                dist[y] = dist[x] + z;
                q.push(std::make_pair(-dist[y], y));
            }
        }
    }
}

signed main()
{
    cin >> n >> m >> k;

    s = 1;

    for (rint i = 1; i <= m; i++)
    {
        int x, y, z;
        cin >> x >> y >> z;
        add(x, y, z);
        add(y, x, z);

        for (rint j = 1; j <= k; j++)
        {
            add(x + j * n, y + j * n, z);
            add(y + j * n, x + j * n, z);
            add(x + (j - 1) * n, y + j * n, z / 2);
            add(y + (j - 1) * n, x + j * n, z / 2);
        }
    }

    dijkstra();

    int ans = inf;

    for (rint i = 0; i <= k; i++)
    {
        ans = std::min(ans, dist[n + i * n]);
    }

    cout << ans << endl;

    return 0;
}

[USACO09FEB] Revamping Trails

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cstring>

#define rint register int
#define endl '\n'

using std::cin;
using std::cout;

const int inf = 0x3f3f3f3f;
const int N = 1e7 + 5;
const int M = 2e7 + 5;

int s, n, m, idx, h[N], e[M], w[M], ne[M], dist[N];
int k;
bool v[N];

std::priority_queue<std::pair<int, int>> q;

void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx;
}

void dijkstra()
{
    memset(dist, 0x3f, sizeof dist);
    memset(v, 0, sizeof v);

    dist[s] = 0;

    q.push(std::make_pair(0, s));

    while (!q.empty())
    {
        int x = q.top().second;
        q.pop();
        if (v[x])
        {
            continue;
        }
        v[x] = true;
        for (rint i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            int z = w[i];
            if (dist[y] > dist[x] + z)
            {
                dist[y] = dist[x] + z;
                q.push(std::make_pair(-dist[y], y));
            }
        }
    }
}

signed main()
{
    cin >> n >> m >> k;

    s = 1;

    for (rint i = 1; i <= m; i++)
    {
        int x, y, z;
        cin >> x >> y >> z;
        add(x, y, z);
        add(y, x, z);

        for (rint j = 1; j <= k; j++)
        {
            add(x + j * n, y + j * n, z);
            add(y + j * n, x + j * n, z);
            add(x + (j - 1) * n, y + j * n, 0);
            add(y + (j - 1) * n, x + j * n, 0);
        }
    }

    dijkstra();

    int ans = inf;

    for (rint i = 0; i <= k; i++)
    {
        ans = std::min(ans, dist[n + i * n]);
    }

    cout << ans << endl;

    return 0;
}

Part3.次短路

求出该图的严格第二短路径,我们以模板题目 [USACO06NOV] Roadblocks 为例

PS:由于求次短路,无需 vis 数组。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>

#define rint register int
#define endl '\n'

using std::cin;
using std::cout;

const int N = 5e5 + 5;
const int M = 1e6 + 5;

int n, m, s;
int idx, h[N], e[M], ne[M], w[M], dist[N];
int need[N];
std::priority_queue<std::pair<int, int> > q;

void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx;
}

void dijkstra()
{
    memset(dist, 0x3f, sizeof dist);
    memset(need, 0x3f, sizeof need);
	
    dist[s] = 0;
    q.push(std::make_pair(0, s));
    
    while (!q.empty())
    {
        int x = -q.top().first; 
	int u = q.top().second;
        q.pop();
        
        for (rint i = h[u]; i; i = ne[i])
        {
            int y = e[i];
	    int z = w[i];
            if (dist[y] > x + z)
            {
                need[y] = dist[y]; 
                dist[y] = x + z;
                q.push(std::make_pair(-dist[y], y));
                q.push(std::make_pair(-need[y], y)); 
            }
            else
            {
                if (need[y] > x + z)
                {
                    need[y] = x + z; 
                    q.push(std::make_pair(-need[y], y));
                }
            }
        }
    }
}

int main()
{
    cin >> n >> m;
    s = 1;

    for (rint i = 1; i <= m; i++)
    {
        int a, b, c;
        cin >> a >> b >> c;
        add(a, b, c);
        add(b, a, c);
    }

    dijkstra();

    cout << need[n] << endl;

    return 0;
}

Part4.无向图最小环

直接跑 Floyd 即可。我们以洛谷例题 P6175 无向图的最小环问题 为例。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define int long long
#define rint register int
#define endl '\n'

using std::cin;
using std::cout;

const int N = 2e2 + 5;

int dist[N][N], tmp[N][N];
int n, m;
int ans = 0x1f1f1f1f;

signed main()
{
    memset(dist, 0x1f, sizeof dist);
    memset(tmp, 0x1f, sizeof tmp);

    cin >> n >> m;

    for (rint i = 1; i <= n; i++)
    {
        dist[i][i] = 0;
    }

    for (rint i = 1; i <= m; i++)
    {
        int a, b, c;
        cin >> a >> b >> c;
        dist[a][b] = dist[b][a] = std::min(dist[a][b], c);
        tmp[a][b] = tmp[b][a] = std::min(tmp[a][b], c);
    }

    for (rint k = 1; k <= n; k++)
    {
        for (rint i = 1; i < k; i++)
        {
            for (rint j = i + 1; j < k; j++)
            {
                ans = std::min(ans, dist[i][j] + tmp[i][k] + tmp[k][j]);
            }
        }

        for (rint i = 1; i <= n; i++)
        {
            for (rint j = 1; j <= n; j++)
            {
                if (dist[i][j] > dist[i][k] + dist[k][j])
                {
                    dist[i][j] = dist[j][i] = dist[i][k] + dist[k][j];
                    dist[j][i] = dist[i][j];
                }
            }
        }
    }
    
    if(ans == 0x1f1f1f1f)
    {
        cout << "No solution." << endl;
        return 0;
    }

    cout << ans << endl;

    return 0;
}

3.动态规划

Part1.线性 dp

线性 dp 是动态规划问题中的一类问题,指状态之间有线性关系的动态规划问题。

这类问题往往是设计状态的时候联想矩阵,之后根据数据范围要求进行优化。

LIS 最长上升子序列

f[i] 表示以 a[i] 为结尾的最长上身子序列的长度

\(f[i] = max(f[j] + 1)\)

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e5 + 5;

int n, ans;
int a[N], f[N];

int main()
{
    read(n);

    for (rint i = 1; i <= n; i++)
    {
	read(a[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
	f[i] = 1;
	for (rint j = 1; j < i; j++)
	{
	    if (a[i] > a[j])
	    {
		f[i] = std::max(f[i], f[j] + 1);
	    }
        }
	ans = std::max(ans, f[i]);
    }

    printf("%d", ans);

    return 0;
}

但是,在这道题目中,会超时。我们可以用树状数组进行优化。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e5 + 5;

int n, m;
int c[N], a[N], f[N];

int lowbit(int x)
{
    return x & -x;
}

void change(int x, int y)
{
    for (; x <= n; x += lowbit(x))
    {
        c[x] = std::max(c[x], y);
    }
}

int ask(int x)
{
    int ans = 0;
    for (; x; x -= lowbit(x))
    {
        ans = std::max(c[x], ans);
    }
    return ans;
}

int res;

int main()
{
    read(n);

    for (rint i = 1; i <= n; i++)
    {
        read(a[i]);
        f[i] = ask(a[i]) + 1;
        change(a[i], f[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        res = std::max(res, f[i]);
    }

    printf("%d", res);

    return 0;
}

LCS 最长公共子序列

f[i][j] 表示两个串从头开始,直到第一个串的第 i 位最多有多少个公共子元素

如果当前的 a[i]b[j] 相同(即是有新的公共元素) 那么
f[i][j]=max(f[i][j],f[i−1][j−1]+1);

如果不相同,即无法更新公共元素:
f[i][j]=max(f[i−1][j],f[i][j−1])

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e3 + 5;

int f[N][N], a[N * 2], a2[N * 2], n, m;

int main()
{
    std::cin >> n >> m;
	
    for (rint i = 1; i <= n; i++)
    {
	scanf("%d", &a[i]);		
    }

    for (int i = 1; i <= m; i++)
    {
	scanf("%d", &b[i]);		
    }

    for (rint i = 1; i <= n; i++)
    {
	for (rint j = 1; j <= m; j++)
	{
	    f[i][j] = std::max(f[i - 1][j], f[i][j - 1]);
	    if (a[i] == b[j])
	    {
		f[i][j] = std::max(f[i][j], f[i - 1][j - 1] + 1);				
	    }
        }		
    }
		
    printf("%d",f[n][m]);
	
    return 0;
}

同理,我们需要用树状数组进行优化。

f[i] 表示在 b 序列中以 i 结尾的最长公共子序列。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(a) scanf("%d", &a)

const int N = 1e5 + 5;

int n, m;
int c[N], a[N], f[N], p[N];

int lowbit(int x)
{
    return x & -x;
}

void change(int x, int y)
{
    for (; x <= n; x += lowbit(x))
    {
        c[x] = std::max(c[x], y);
    }
}

int ask(int x)
{
    int ans = 0;
    for (; x; x -= lowbit(x))
    {
        ans = std::max(c[x], ans);
    }
    return ans;
}

int main()
{
    read(n);

    for (rint i = 1; i <= n; i++)
    {
        read(a[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        int x;
        read(x);
        p[x] = i;
    }

    for (rint i = 1; i <= n; i++)
    {
        int x = p[a[i]];
        f[x] = ask(x) + 1;
        change(x, f[x]);
    }

    printf("%d", ask(n));

    return 0;
}

LCIS 最长公共上升子序列

f[i][j] 代表所有 a[1 ~ i]b[1 ~ j] 中以 b[j] 结尾的公共上升子序列的集合

结合前两题,我们可以很显然的打出

for (rint i = 1; i <= n; i ++ )
{
    for (rint j = 1; j <= n; j ++ )
    {
        f[i][j] = f[i - 1][j];
        if (a[i] == b[j])
        {
            int maxx = 1;
            for (rint k = 1; k < j; k ++ )
                if (a[i] > b[k])
                    maxx = max(maxx, f[i - 1][k] + 1);
            f[i][j] = max(f[i][j], maxx);
        }
    }
}

但是空间复杂度过高,因此可以直接将 maxx 提到第一层循环外面,减少重复计算。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(a) scanf("%d", &a)

const int N = 5e3 + 5;

int n;
int a[N], b[N];
int f[N][N];

int main()
{
    read(n);

    for (rint i = 1; i <= n; i++)
    {
        read(a[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        read(b[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        int maxx = 1;
        for (rint j = 1; j <= n; j++)
        {
            f[i][j] = f[i - 1][j];
            if (a[i] == b[j])
            {
                f[i][j] = std::max(maxx, f[i][j]);
            }
            if (a[i] > b[j])
            {
                maxx = std::max(maxx, f[i - 1][j] + 1);
            }
        }
    }

    int ans = 0;

    for (rint i = 1; i <= n; i++)
    {
        ans = std::max(ans, f[n][i]);
    }

    printf("%d", ans);

    return 0;
}

[CSP-J2020] 方格取数

f[i][j][→]f[i][j][0]) 表示从当前格子的左边走到当前格子能取到的最大整数之和。

f[i][j][↓] (f[i][j][1]) 表示当前格子的上边走到当前格子能取到的最大整数之和。

f[i][j][↑]f[i][j][2]) 表示当前格子的下边走到当前格子能取到的最大整数之和。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define int long long
#define rint register int
#define endl '\n'

#define read(a) scanf("%lld", &a)

const int N = 2e3 + 5;

using std::max;

int n, m;
int a[N][N];
int f[N][N][3];

signed main()
{
    read(n);
    read(m);

    memset(f, 0xcf, sizeof f);

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = 1; j <= m; j++)
        {
            read(a[i][j]);
        }
    }

    f[1][1][0] = a[1][1];
    f[1][1][1] = a[1][1];
    f[1][1][2] = a[1][1];

    for (rint i = 2; i <= n; i++)
    {
        f[i][1][1] = f[i - 1][1][1] + a[i][1];
    }

    for (rint j = 2; j <= m; j++)
    {
        for (rint i = 1; i <= n; i++)
        {
            f[i][j][0] = max(f[i][j - 1][0], max(f[i][j - 1][1], f[i][j - 1][2])) + a[i][j];
            if (i != 1)
            {
                f[i][j][1] = max(f[i - 1][j][0], f[i - 1][j][1]) + a[i][j];
            }
        }
        for (rint i = n - 1; i >= 1; i--)
        {
            f[i][j][2] = max(f[i + 1][j][0], f[i + 1][j][2]) + a[i][j];
        }
    }

    printf("%lld", max(f[n][m][0], max(f[n][m][1], f[n][m][2])));

    return 0;
}

[NOIP2000] 方格取数

f[i][j][len] 表示一个人走到横坐标为 \(i\) 的点,第二个人走到横坐标为 \(j\) 的点,一共前进了 \(len\) 次能去得的最大值。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d",&x) 

using std::max;
using std::min;

const int N = 1e2 + 5;

int n, a[N][N], f[N][N][N];
int x, y, z;

int main()
{
    read(n);
    read(x);
    read(y);
    read(z);
    
    while (x != 0)
    {
        a[x][y] = z;
        read(x);
        read(y);
        read(z);
    }
    
    f[1][1][0] = a[1][1];
    
    for (rint l = 1; l <= 2 * n; l++)
    {
        for (rint i = 1; i <= n; i++)
        {
            for (rint j = 1; j <= n; j++)
            {
		if (i == j)
                {
                    //可以合并起来写,但是复习的时候不方便看,就拆开算了 
		    f[i][j][l] = max(f[i][j][l], f[i - 1][j - 1][l - 1] + a[i][l + 2 - i]);
                    
                    f[i][j][l] = max(f[i][j][l], f[i - 1][j][l - 1] + a[i][l + 2 - i]);
                    
                    f[i][j][l] = max(f[i][j][l], f[i][j - 1][l - 1] + a[i][l + 2 - i]);
                    
                    f[i][j][l] = max(f[i][j][l], f[i][j][l - 1] + a[i][l + 2 - i]);
                }
                if (i != j)
                {
                    f[i][j][l]=max(f[i][j][l],f[i-1][j-1][l-1]+a[i][l+2-i]+a[j][l+2-j]);

                    f[i][j][l]=max(f[i][j][l],f[i][j-1][l-1]+a[i][l+2-i]+a[j][l+2-j]);

                    f[i][j][l]=max(f[i][j][l],f[i-1][j][l-1]+a[i][l+2-i]+a[j][l+2-j]);

                    f[i][j][l]=max(f[i][j][l],f[i][j][l-1]+a[i][l+2-i]+a[j][l+2-j]);
                }
            }				
	}
    }

    printf("%d",f[n][n][2 * (n - 1)]);
    
    return 0;
}

AcWing.1017 怪盗基德的滑翔翼

非常板子的一道题,跑一边最长上升子序列,再跑一遍最长下降子序列,这个题就切了。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e2 + 5;

int n, m, a[N];
int f1[N], f2[N];

int main()
{
    scanf("%d", &m);
	
    for(rint k = 1; k <= m; k++)
    {
	scanf("%d", &n);
	for(rint i = 1; i <= n; i++)
	{
	    scanf("%d", &a[i]);
	    f1[i] = f2[i] = 1;
	}
	f1[0] = f2[0] = 1;
		
	for(rint i = 2; i <= n; i++)
	{
	    for(rint j = 1; j < i; j++)
	    {
		if(a[j] < a[i])
		{
		    f1[i] = std::max(f1[i], f1[j] + 1);
		}
	    }
        }
		
	for(rint i = n - 1; i >= 1; i--)
	{
	    for(rint j = n; j > i; j--)
	    {
		if(a[j] < a[i])
		{
		    f2[i] = std::max(f2[i], f2[j] + 1);
		}
	    }
	}	
		
	int ans = 0;
	for(rint i = 1; i <= n; i++)
	{
	    ans = std::max(ans, std::max(f1[i], f2[i]));
	}	
	printf("%d\n", ans);
    }
	
    return 0;
}

[HZOI] 回文

给定一个 \(n\)\(m\) 列的字符矩阵 \(A\) ,以 \((1,1)\) 为起点,只能往下或者往右走,求到 \((n,m)\) 的回文路径有多少条。\((n,m≤500)\)

\(f_{i,j,k}\) 为起点向下走了 \(i\) 步,终点向上走了 \(j\) 步,此时一共都走了 \(k\) 步的目前可能回文方案数。

每个点可以往两个方向走,可以得出转移方程:

\(f_{i,j,k}=f_{i,j,k−1}+f_{i−1,j,k−1}+f_{i,j−1,k−1}+f_{i−1,j−1,k−1}\)

\(k\) 的上界 \(⌊(n+m−2)/2⌋\) 范围下,两个起点的最终位置要么相交为一个点,要么就是相邻。相交的直接加就行了,相邻的以起点的角度看,终点在右边可以,在下边也可以,方案相加即可。

不难发现 \(k\) 在转移过程中涉及到的值永远只差一,可以滚动数组优化空间。(代码中状态定义把这个放到了第一维)

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e3 + 5; 
const int mod = 998244353;

int n, m;
int k = 0;
char a[N][N];
int f[2][N][N];

int main()
{
    scanf("%d%d", &n, &m);
	
    for(rint i = 1; i <= n; i++)
    {
	for(rint j = 1; j <= m; j++)
	{
	    std::cin >> a[i][j];
        }
    }
	
    if(a[1][1] != a[n][m])
    {
	puts("0");
	return 0;
    }
	
    f[0][0][0] = 1;
	
    for(rint len = 1; len <= (n + m - 2) / 2; len++)
    {
	k = k ^ 1;
	int t = std::min((int)len, n - 1);
		
	for(rint i = 0; i <= t; i++)
	{
	    for(rint j = 0; j <= t; j++)
	    {
                if (a[1 + i][1 + len - i] != a[n - j][m - len + j])
                {
		    continue;
	        }
				
	        f[k][i][j] = (f[k][i][j] + f[k ^ 1][i][j]) % mod;
	        if(i) f[k][i][j] = (f[k][i][j] + f[k ^ 1][i - 1][j]) % mod;			
	        if(j) f[k][i][j] = (f[k][i][j] + f[k ^ 1][i][j - 1]) % mod;
	        if(i && j) f[k][i][j] = (f[k][i][j] + f[k ^ 1][i - 1][j - 1]) % mod;
	    }
        }
        for(rint i = 0; i <= t; i++)
        {
	    for(rint j = 0; j <= t; j++)
	    {
	        f[k ^ 1][i][j] = 0;
	    }
        }
    }
	
    long long ans = 0;
	
    for(rint i = 1; i <= n; i++)
    {
	for(rint j = 1; j <= m; j++)
	{
            if (i - 1 + j - 1 != (n + m - 2) / 2)
            {
		continue;
	    }
            if (n - i + m - j == (n + m - 2) / 2)
            {
                ans = (ans + f[k][i - 1][n - i]) % mod;
                continue;
            }
            if (n - i >= 0 && (n + m - 2) / 2 - n + i >= 0)
            {
                ans = (ans + f[k][i - 1][n - i]) % mod;				
	    }
            if (n - i - 1 >= 0 && (n + m - 2) / 2 - n + i + 1 >= 0)
            {
                ans = (ans + f[k][i - 1][n - i - 1]) % mod;					
	    }		
        }
    }
	
    printf("%lld", ans);
	
    return 0;
}

AcWing.732 过河

\(f[i]\) 表示第 \(i\) 个人到第 \(n\) 个人运过去花的时间

每次过三个人,回两个人...最后一趟过三个人

分情况:用两个最小的数运一个最大数、回来的就是这两个小数,用一个小数运两个大数(需要保证对面有一个小数,跟着回来),运过去三个数,都留在对面,对面选两个回来(需要保证对面至少有两个数)

#include <iostream>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

using std::min;

int a[N];
int n;
long long f[N];

int main()
{
    int T;
    scanf("%d", &T);
    while (T--)
    {
        scanf("%d", &n);

        for (rint i = 1; i <= n; i++)
        {
            scanf("%d", &a[i]);
        }
        std::sort(a + 1, a + n + 1);

        if (n <= 3)
        {
            printf("%d\n", a[n]);
            continue;
        }
        f[n + 1] = 0;
        
        for (rint i = n; i >= 4; i--)
        {
            //第一种情况,把第i个人运过去
            f[i] = a[i] + a[2] + f[i + 1];
            
            //第二种情况,把第i个人和第i+1个人运过去
            if (i + 1 <= n)
            {
                f[i] = min(f[i], a[3] + a[2] + a[i+1] + a[3] + f[i+2]);				
	    }
			
            //第三种情况,把第i、i+1、i+2个人运过去
            if (i + 2 <= n)
            {
                f[i]=min(f[i],a[3]+a[2]+a[4]+a[2]+a[i+2]+a[4]+f[i+3]);	
	    }
        }
        printf("%lld\n", f[4] + a[3]);
    }
    
    return 0;
}

AcWing.273 分级

对单调上升和单调下降分别求一遍,取最小值。

结论:一定存在一组最优解,使得每个 \(B_i\) 都是原序列中的某个值。

f[i][j] 是给 A[1] ~ A[i] 分配好了值且最后一个 B[i] = A'[j] 中所有方案的最小值

据倒数第二个数分配的是哪个 A'[i]f[i][j] 所代表的集合划分成j个不重不漏的子集:

倒数第二个数选取的是 A'[1] 的所有方案的结合,最小值是 f[i - 1][1] + abs(A[i] - A'[j]);

倒数第二个数选取的是 A'[2] 的所有方案的结合,最小值是 f[i - 1][2] + abs(A[i] - A'[j]);

倒数第二个数选取的是 A'[j] 的所有方案的结合,最小值是 f[i - 1][j] + abs(A[i] - A'[j]);

f[i][j] 在所有子集的最小值中取 min 即可。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 2e3 + 5;

int n;
int a[N], b[N];
int f[N][N];

int work()
{
    for (rint i = 1; i <= n; i++) 
    {
        b[i] = a[i];
    }
    
    std::sort(b + 1, b + n + 1);

    for (rint i = 1; i <= n; i++)
    {
        int minn = 1e9;
        for (rint j = 1; j <= n; j++)
        {
            minn = std::min(minn, f[i - 1][j]);
            f[i][j] = minn + std::abs(a[i] - b[j]);
        }
    }

    int res = 1e9;
    
    for (rint i = 1; i <= n; i++) 
    {
        res = std::min(res, f[n][i]);
    }

    return res;
}

int main()
{
    scanf("%d", &n);
    
    for (rint i = 1; i <= n; i++) 
    {
        scanf("%d", &a[i]);
    }

    int res = work();
    std::reverse(a + 1, a + n + 1);
    res = std::min(res, work());

    printf("%d\n", res);

    return 0;
}

S2OJ#.896 折射

\(dp[i][0/1]\) 表示按 \(x\) 排序,以 \(i\) 为入射点的光线,向左下方 /右下方入射

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int 
#define endl '\n'

const int N = 6e3 + 5;
const int mod = 1e9 + 7;

int n;
long long f[N][2];

struct node  
{
    long long x, y;
} a[N];

bool operator<(node a, node b) // 按 x 排列 
{
    return a.x < b.x;
}

signed main()
{
    scanf("%d", &n);
    
    for (rint i = 1; i <= n; i++)
    {
        scanf("%lld%lld", &a[i].x, &a[i].y);		
    }

    std::sort(a + 1, a + n + 1);
    
    for(rint i = 1; i <= n; i++)
    {
	f[i][0] = 1;
	f[i][1] = 1;
    }
    
    for (rint i = 1; i <= n; i++)
    {
        for (rint j = i - 1; j >= 1; j--)
	{
            if (a[j].y < a[i].y)
            {
                f[i][0] = (f[i][0] + f[j][1]) % mod;					
	    }
	    else
	    {
	        f[j][1] = (f[j][1] + f[i][0]) % mod;
	    }    
        }
    }
    
    long long ans = 0;
    
    for (rint i = 1; i <= n; i++)
    {
	ans = ans + f[i][0] + f[i][1] - 1;
	ans %= mod;
    }
    
    printf("%lld", ans % mod);
    
    return 0;
}

s2OJ#1450. 大相公

一个特别恶心的 dfs 的动态规划,不知道放在哪里,就放在线性 dp 里吧;

\(f_{i,j,k}\) 表示第 \(i\) 位(也就是第 \(i\) 种数字),上一位剩了 \(j\) 个,这一位剩了 \(k\) 个(全都是给顺子剩的)

然后需要加个小剪枝:然后你发现如果 \(f_i\) 的所有状态都为 0,那后面不能是 1 ,直接 return

更多内容见郭队 S2OJ Blog

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;
const int M = 10;
const int mod = 998244353;

int n, len, p[M], ans;
bool f[M][3][3];
char s[N];

int add(int x, int y)
{
    return (x + y >= mod ? x + y - mod : x + y);
}

int mul(int x, int y)
{
    return 1ll * x * y % mod;
}

int qpow(int a, int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if (b & 1)
        {
            res = mul(res, a);			
		}
        a = mul(a, a);
    }
    return res;
}

int c[N][10];

int C(int n, int m){
    if(m == 0 || m == n)  
	return 1;
    if(c[n][m] != 0) 
	return c[n][m];
    return c[n][m] = add(C(n - 1, m) , C(n - 1, m - 1));
}

void dfs(int i, int u, int t)
{
    if (i > 9)
    {
        u = p[0] - u;
        if (u < 0)
        {
            return;
        }
        if (u > 0 && t == 0)
        {
            return;
        }
        if (f[9][0][0])
        {
            ans = add(ans, C(u + t - 1, t - 1));
        }
        return;
    }

    if (p[i] > 4)
    {
        f[i][0][0] |= f[i - 1][0][0];
        f[i][0][1] |= f[i - 1][0][0];
        f[i][1][1] |= f[i - 1][0][1];
        f[i][0][0] |= f[i - 1][1][1];
        f[i][0][1] |= f[i - 1][1][1];
        f[i][1][1] |= f[i - 1][1][2];
        f[i][0][2] |= f[i - 1][0][0];
        f[i][1][2] |= f[i - 1][0][1];
        f[i][0][2] |= f[i - 1][1][1];
        f[i][1][2] |= f[i - 1][1][2];

        if (f[i][0][0] == 0 && f[i][0][1] == 0 && f[i][1][1] == 0 && f[i][1][2] == 0)
        {
            return;
        }
        dfs(i + 1, u, t + 1);
        f[i][0][0] = f[i][0][1] = f[i][1][1] = f[i][1][2] = 0;
        return;
    }

    for (rint j = p[i]; j <= 5; j++)
    {
        if (j == 0)
        {
            f[i][0][0] |= f[i - 1][0][0];
        }
        else if (j == 1)
        {
            f[i][0][1] |= f[i - 1][0][0];
            f[i][1][1] |= f[i - 1][0][1];
            f[i][0][0] |= f[i - 1][1][1];
        }
        else if (j == 2)
        {
            f[i][0][2] |= f[i - 1][0][0];
            f[i][1][2] |= f[i - 1][0][1];
            f[i][0][1] |= f[i - 1][1][1];
            f[i][1][1] |= f[i - 1][1][2];
            f[i][0][0] |= f[i - 1][0][0];
        }
        else if (j == 3)
        {
            f[i][0][0] |= f[i - 1][0][0];
            f[i][0][1] |= f[i - 1][0][0];
            f[i][1][1] |= f[i - 1][0][1];
            f[i][0][0] |= f[i - 1][1][1];
            f[i][0][2] |= f[i - 1][1][1];
            f[i][1][2] |= f[i - 1][1][2];
        }
        else if (j == 4)
        {
            f[i][0][0] |= f[i - 1][0][0];
            f[i][0][1] |= f[i - 1][0][0];
            f[i][1][1] |= f[i - 1][0][1];
            f[i][0][0] |= f[i - 1][1][1];
            f[i][0][1] |= f[i - 1][1][1];
            f[i][1][1] |= f[i - 1][1][2];
            f[i][0][2] |= f[i - 1][0][0];
            f[i][1][2] |= f[i - 1][0][1];
        }
        else if (j == 5)
        {
            f[i][0][0] |= f[i - 1][0][0];
            f[i][0][1] |= f[i - 1][0][0];
            f[i][1][1] |= f[i - 1][0][1];
            f[i][0][0] |= f[i - 1][1][1];
            f[i][0][1] |= f[i - 1][1][1];
            f[i][1][1] |= f[i - 1][1][2];
            f[i][0][2] |= f[i - 1][0][0];
            f[i][1][2] |= f[i - 1][0][1];
            f[i][0][2] |= f[i - 1][1][1];
            f[i][1][2] |= f[i - 1][1][2];
        }
        if (f[i][0][0] == 0 && f[i][0][1] == 0 && f[i][1][1] == 0 && f[i][1][2] == 0)
        {
            continue;
        }
        dfs(i + 1, u + j - p[i], j == 5 ? t + 1 : t);
        f[i][0][0] = f[i][0][1] = f[i][1][1] = f[i][1][2] = 0;
    }
}

int main()
{
    scanf("%d", &n);
    scanf("%s", s);
    
    len = strlen(s);
    
    for (rint i = 0; i < len; i++)
    {
        if (s[i] == '?')
        {
	    p[0]++;
	}
        else
        {
            p[s[i] - 48]++;			
	}	
    }
            
    f[0][0][0] = 1;
    
    dfs(1, 0, 0);
    
    printf("%d\n", ans);
    
    return 0;
}

Part2.背包 dp

背包 dp 解决的问题就是有一个容积一定的背包,我们要往里放东西,每个物品都有自己的价值,在一些其他元素的影响下,保证放进背包物品的价值最大。

01 背包

\(N\) 件物品和一个容量是 \(V\) 的背包。每件物品只能使用一次。

\(i\) 件物品的体积是 \(v_i\),价值是 \(w_i\)

求解将哪些物品装入背包,可使这些物品的总体积不超过背包容量,且总价值最大。

首先考虑二维解法,状态 f[i][j] 定义:前 i 个物品,背包容量 j 下的最大价值:

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e3 + 5;

int n, m;
int f[N][N];
int v[N], w[N];

int main()
{
    read(n);
    read(m);

    for (rint i = 1; i <= n; i++)
    {
        read(v[i]);
        read(w[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = 0; j <= m; j++)
        {
            f[i][j] = f[i - 1][j];
            if (j >= v[i])
            {
                f[i][j] = std::max(f[i][j], f[i - 1][j - v[i]] + w[i]);
            }
        }
    }

    printf("%d", f[n][m]);

    return 0;
}

考虑优化空间,状态 f[j] 定义:N 件物品,背包容量 j 下的最优解。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e3 + 5;

int n, m;
int f[N];
int v[N], w[N];

int main()
{
    read(n);
    read(m);

    for (rint i = 1; i <= n; i++)
    {
        read(v[i]);
        read(w[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = m; j >= v[i]; j--)
        {
            f[j] = std::max(f[j], f[j - v[i]] + w[i]);
        }
    }

    printf("%d", f[m]);

    return 0;
}

完全背包问题

\(N\) 种物品和一个容量是 \(V\) 的背包,每种物品都有无限件可用。

\(i\) 种物品的体积是 \(v_i\),价值是 \(w_i\)

求解将哪些物品装入背包,可使这些物品的总体积不超过背包容量,且总价值最大。

状态设置同 01 背包

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e3 + 5;

int n, m;
int f[N][N];
int v[N], w[N];

int main()
{
    read(n);
    read(m);

    for (rint i = 1; i <= n; i++)
    {
        read(v[i]);
        read(w[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = 0; j <= m; j++)
        {
            for (int k = 0; k * v[i] <= j; k++)
            {
                f[i][j] = std::max(f[i][j], f[i - 1][j - k * v[i]] + k * w[i]);
            }
        }
    }

    printf("%d", f[n][m]);

    return 0;
}
#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e3 + 5;

int n, m;
int f[N];
int v[N], w[N];

int main()
{
    read(n);
    read(m);

    for (rint i = 1; i <= n; i++)
    {
        read(v[i]);
        read(w[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = v[i]; j <= m; j++)
        {
            f[j] = std::max(f[j], f[j - v[i]] + w[i]);
        }
    }

    printf("%d", f[m]);

    return 0;
}

多重背包问题

\(N\) 种物品和一个容量是 \(V\) 的背包。

\(i\) 种物品最多有 \(s_i\) 件,每件体积是 \(v_i\),价值是 \(w_i\)

求解将哪些物品装入背包,可使物品体积总和不超过背包容量,且价值总和最大。

首先是暴力解法

状态 f[i][j] 定义:前 i 个物品,背包容量 j 下的最大价值:

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e3 + 5;

int n, m;
int f[N][N];
int s[N], v[N], w[N];

int main()
{
    read(n);
    read(m);

    for (rint i = 1; i <= n; i++)
    {
        read(v[i]);
        read(w[i]);
        read(s[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = 0; j <= m; j++)
        {
            for (int k = 0; k <= s[i]; k++)
            {
                if(j >= k * v[i])
                {
                    f[i][j] = std::max(f[i][j], f[i - 1][j - k * v[i]] + k * w[i]);                    
                }
            }
        }
    }

    printf("%d", f[n][m]);

    return 0;
}

之后可以二进制拆分优化

二进制拆分原理:任何一个整数都可以转换成一个若干个2k数相加的形式(因为可以转化成二进制数)。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 2e3 + 5;

int f[N], v[N], w[N];
int n, m;

int main()
{
    read(n);
    read(m);

    int idx = 0;
    for (rint i = 1; i <= n; i++)
    {
        int a, b, s;
        read(a);
        read(b);
        read(s);
        int k = 1;
        while (k <= s)
        {
            idx++;
            v[idx] = k * a;
            w[idx] = k * b;
            s -= k;
            k <<= 1;
        }
        if (s > 0)
        {
            idx++;
            v[idx] = a * s;
            w[idx] = b * s;
        }
    }

    for (rint i = 1; i <= idx; i++)
    {
        for (rint j = m; j >= v[i]; j--)
        {
            f[j] = std::max(f[j], f[j - v[i]] + w[i]);
        }
    }

    printf("%d", f[m]);

    return 0;
}

最后就是单调队列优化 dp

#include <iostream>
#include <cstring>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 2e4 + 5;

int n, m;
int f[N], g[N], q[N];
int v[N], s[N], w[N];

int main()
{
    read(n);
    read(m);

    for (rint i = 1; i <= n; i++)
    {
        read(v[i]);
        read(w[i]);
        read(s[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        memcpy(g, f, sizeof f);

        for (rint j = 0; j < v[i]; j++)
        {
            int hh = 0, tt = -1;
            for (rint k = j; k <= m; k += v[i])
            {
                if (hh <= tt && q[hh] < k - s[i] * v[i])
                {
                    hh++;
                }
                if (hh <= tt)
                {
                    f[k] = std::max(f[k], g[q[hh]] + (k - q[hh]) / v[i] * w[i]);
                }
                while (hh <= tt && g[q[tt]] - (q[tt] - j) / v[i] * w[i] <= g[k] - (k - j) / v[i] * w[i])
                {
                    tt--;
                }
                q[++tt] = k;
            }
        }
    }

    printf("%d", f[m]);

    return 0;
}

混合背包问题

\(N\) 种物品和一个容量是 \(V\) 的背包。

物品一共有三类:

第一类物品只能用 1 次(01背包);
第二类物品可以用无限次(完全背包);
第三类物品最多只能用 \(s_i\) 次(多重背包);
每种体积是 \(v_i\),价值是 \(w_i\)

求解将哪些物品装入背包,可使物品体积总和不超过背包容量,且价值总和最大。

第一行两个整数,\(N\)\(V\),用空格隔开,分别表示物品种数和背包容积。

接下来有 \(N\) 行,每行三个整数 \(v_i,w_i,s_i\),用空格隔开,分别表示第 i 种物品的体积、价值和数量。

\(s_i=−1\) 表示第 \(i\) 种物品只能用 \(1\) 次;
\(s_i=0\) 表示第 \(i\) 种物品可以用无限次;
\(s_i>0\) 表示第 \(i\) 种物品可以使用 \(s_i\) 次;

二进制拆分即可

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e5 + 5;

int n, m;
int f[N];
int v[N], w[N];

int main()
{
    read(n);
    read(m);

    int idx = 0;
    for (rint i = 1; i <= n; i++)
    {
        int a, b, s;
        read(a);
        read(b);
        read(s);
        int k = 1;
        if (s < 0)
        {
            s = 1;
        }
        if (s == 0)
        {
            s = m / a;
        }
        while (k <= s)
        {
            idx++;
            v[idx] = k * a;
            w[idx] = k * b;
            s -= k;
            k <<= 1;
        }
        if (s > 0)
        {
            idx++;
            v[idx] = a * s;
            w[idx] = b * s;
        }
    }

    for (rint i = 1; i <= idx; i++)
    {
        for (rint j = m; j >= v[i]; j--)
        {
            f[j] = std::max(f[j], f[j - v[i]] + w[i]);
        }
    }

    printf("%d", f[m]);

    return 0;
}

二维费用的背包问题

\(N\) 件物品和一个容量是 \(V\) 的背包,背包能承受的最大重量是 \(M\)

每件物品只能用一次。体积是 \(v_i\),重量是 \(m_i\),价值是 \(w_i\)

求解将哪些物品装入背包,可使物品总体积不超过背包容量,总重量不超过背包可承受的最大重量,且价值总和最大。

二维费用背包 dp 可谓背包 dp 方面最常考的问题,打上 TAG !!!QWQ

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e3 + 5;

int n, t, p;
int f[N][N];
int v[N], m[N], w[N];

int main()
{
    read(n);
    read(t);
    read(p);

    for (rint i = 1; i <= n; i++)
    {
        read(v[i]);
        read(m[i]);
        read(w[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = p; j >= m[i]; j--)
        {
            for (rint k = t; k >= v[i]; k--)
            {	
                f[j][k] = std::max(f[j][k], f[j - m[i]][k - v[i]] + w[i]);			
	    }			
	}		
    }

    printf("%d", f[p][t]);

    return 0;
}

分组背包问题

\(N\) 组物品和一个容量是 \(V\) 的背包。

每组物品有若干个,同一组内的物品最多只能选一个。
每件物品的体积是 \(v_{ij}\),价值是 \(w_{ij}\),其中 \(i\) 是组号,\(j\) 是组内编号。

求解将哪些物品装入背包,可使物品总体积不超过背包容量,且总价值最大。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e2 + 5;

int n,m;
int v[N][N],w[N][N],s[N];
int f[N];

int main()
{
    scanf("%d%d", &n, &m);
    
    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &s[i]);
        for (rint j = 1; j <= s[i]; j++)
	{
            scanf("%d%d", &v[i][j], &w[i][j]);
        }
    }
    
    for (rint i = 1; i <= n; i++)
    {
        for(rint j = m; j >= 0; j--)
	{
            for(rint k = 0; k <= s[i]; k++)
	    {
                if(j >= v[i][k])
		{
                    f[j] = std::max(f[j], f[j - v[i][k]] + w[i][k]);
                }
            }
        }
    }
    
    printf("%d", f[m]);
    
    return 0;
}

背包问题求方案数

\(N\) 件物品和一个容量是 \(V\) 的背包。每件物品只能使用一次。

\(i\) 件物品的体积是 \(v_i\),价值是 \(w_i\)

求解将哪些物品装入背包,可使这些物品的总体积不超过背包容量,且总价值最大。

输出 最优选法的方案数。注意答案可能很大,请输出答案模 \(10^9+7\) 的结果。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e5 + 5;
const int mod = 1e9 + 7;

int n, m;
int f[N], g[N];
int v[N], w[N];

int main()
{
    read(n);
    read(m);

    for (rint i = 0; i <= m; i++)
    {
        g[i] = 1;
    }

    for (rint i = 1; i <= n; i++)
    {
        read(v[i]);
        read(w[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = m; j >= v[i]; j--)
        {
            if (f[j] < f[j - v[i]] + w[i])
            {
                g[j] = g[j - v[i]];
                f[j] = f[j - v[i]] + w[i];
            }
            else if (f[j] == f[j - v[i]] + w[i])
            {
                g[j] = (g[j] + g[j - v[i]]) % mod;
            }
        }
    }

    printf("%d", g[m]);

    return 0;
}

背包问题求具体方案

\(N\) 件物品和一个容量是 \(V\) 的背包。每件物品只能使用一次。

\(i\) 件物品的体积是 \(v_i\),价值是 \(w_i\)

求解将哪些物品装入背包,可使这些物品的总体积不超过背包容量,且总价值最大。

输出 字典序最小的方案。这里的字典序是指:所选物品的编号所构成的序列。物品的编号范围是 \(1…N\)

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e3 + 5;

int n, m;
int w[N], v[N];
int f[N][N];
int path[N];

int main()
{
    read(n);
    read(m);
    
    for (rint i = 1; i <= n; i++)
    {
	read(v[i]);
	read(w[i]);
    }

    for (rint i = n; i >= 1; i--)
    {
        for (rint j = 0; j <= m; j++)
        {
            f[i][j] = f[i + 1][j];
            if (j >= v[i])
            {
                f[i][j] = std::max(f[i][j], f[i + 1][j - v[i]] + w[i]);				
	    }
        }
    }
    
    int idx = 0;
    for (rint i = 1, j = m; i <= n; i++)
    {
        if (j >= v[i] && f[i][j] == f[i + 1][j - v[i]] + w[i])
        {
            path[++idx] = i;
            j -= v[i];
        }
    }
    
    for (rint i = 1; i <= idx; i++)
    {
	printf("%d ",path[i]);
    }
    
    return 0;
}

[洛谷月赛] 小挖的买花

非常非常水的一个二维费用背包 dp

但是赛时少了亿点特判,导致 WA + MLE

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d",&x)

const int N = 5e2 + 5;
const int M = 500;

int n, q, a, b;
int v[N], f[N], w[N];
int dp[N][N];

int main()
{
    read(n);
    read(q);

    for (rint i = 1; i <= n; i++)
    {
        read(v[i]);
        read(f[i]);
        read(w[i]);
    }

    memset(dp, 0xc0, sizeof dp);
    dp[0][0] = 0;

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = M; j >= 0; j--)
        {
            for (rint k = M; k >= v[i]; k--)
            {
                dp[j][k] = std::max(dp[j][k], dp[std::max(0, j - f[i])][k - v[i]] + w[i]);
            }
            for (int k = 1; k <= M; k++)
            {
                dp[j][k] = std::max(dp[j][k], dp[j][k - 1]);
            }
        }
    }

    while (q--)
    {
        int a, b;
        read(a);
        read(b);
        if (dp[b][a] > 0)
        {
            printf("%d\n",dp[b][a]);
        }
        else
        {
            puts("0");
        }
    }

    return 0;
}

[USACB]Stock Market

假设每天买完第二天就卖掉,如果第二天不卖掉可以看成是第二天卖掉再买入。这样就把问题转化成了完全背包问题。一天一天DP即可。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e2 + 5;
const int M = 1e7 + 5;

int n, t, m;
int a[N][N];
int f[M];

int main()
{
    read(n);
    read(t);
    read(m);
	
    for(rint i = 1; i <= n; i++)
    {
	for(rint j = 1; j <= t; j++)
	{
		read(a[i][j]);
	}
    }
	
    for(rint i = 1; i <= t; i++)
    {
	memset(f, 0, sizeof f);
	for(rint j = 1; j <= n; j++)
	{
	    for(rint k = a[j][i]; k <= m; k++)
	    {
		f[k] = std::max(f[k], f[k - a[j][i]] + a[j][i + 1] - a[j][i]);
	    }
	}
        m += f[m];
    }
	
    printf("%d", m);
	
    return 0;
}

P1156 垃圾陷阱

f[i] 表示高度堆到 i 时的最大生命值。

代码里面有注释。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

#define read(x) scanf("%d", &x)

const int N = 1e4 + 5;

struct node
{
    int t, h, f;
} a[N];

bool operator<(node a, node b)
{
    return a.t < b.t;
}

int hight, n;
int f[N];

int main()
{
    read(hight);
    read(n);

    for (rint i = 1; i <= n; i++)
    {
        read(a[i].t);
        read(a[i].f);
        read(a[i].h);
    }

    std::sort(a + 1, a + n + 1);
    f[0] = 10;

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = hight; j >= 0; j--)
        {
            if (f[j] >= a[i].t)  //如果当前垃圾掉落时的生命值不小于掉落时间
            {
                if (j + a[i].h >= hight) //并且如果这个垃圾的堆叠高度恰好等于或大于坑的高度
                {
                    printf("%d", a[i].t);  //输出当前时间
                    return 0;
                }
                f[j + a[i].h] = std::max(f[j + a[i].h], f[j]);  ///否则这个高度+这个垃圾的高度等于从0到hight的生命值
                f[j] += a[i].f;           //吃垃圾,续命
            }
        }
    }

    printf("%d", f[0]);

    return 0;
}

P4141 消失之物

背包求解方案数板子题;

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 5e3 + 5;

int n, m;
int w[N], f[N], g[N];

int main()
{
    scanf("%d%d", &n, &m);
    
    for (rint i = 1; i <= n; i++)
    {
	scanf("%d", &w[i]);
    }
	
    f[0] = 1;
    g[0] = 1;
    
    for (int i = 1; i <= n; i++)
    {
        for (int j = m; j >= w[i]; j--)
        {
            f[j] = (f[j] + f[j - w[i]]) % 10;			
	}		
    }      
	      
    for (rint i = 1; i <= n; i++)
    {
        for (rint j = 1; j <= m; j++)
        {
            if (w[i] > j)
            {
                g[j] = f[j] % 10;				
	    }
            else
            {
                g[j] = (f[j] - g[j - w[i]] + 10) % 10;				
	    }
            printf("%d", g[j]);
        }
        puts("");
    }
    
    return 0;
}

P2079 烛光晚餐

这个题很显然可以发现,存在负数的情况。

怎么办???

没事儿,数组开大一点,背包正常写,我们发现负数最小为 -500, 所以每次调用数组的时候加个 500 就行了。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>

#define rint register int
#define endl '\n'

const int inf = 0x3f3f3f3f;
const int N = 1e3 + 5;
const int t = 500;

int n, m;
int f[N][N];
int c[N], w[N], v[N];

int main()
{
    scanf("%d%d", &n, &m);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d%d%d", &c[i], &v[i], &w[i]);
    }

    memset(f, 0xcf, sizeof f);
    f[0][t] = 0;

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = m; j >= c[i]; j--)
        {
            for (rint k = t; k >= -t; k--)
            {
                if (abs(k - c[i] <= t) <= t)
                {
                    f[j][k + t] = std::max(f[j][k + t], f[j - c[i]][k - v[i] + t] + w[i]);
                }
            }
        }
    }

    int ans = -inf;

    for (rint i = 0; i <= m; i++)
    {
        for (rint j = 0; j <= t; j++)
        {
            ans = std::max(ans, f[i][j + t]);
        }
    }

    printf("%d", ans);

    return 0;
}

P4832 珈百璃堕落的开始

披着数学的皮的背包问题。

众所周知,\(sin^2α + cos^2α = 1\),题目中说答案必须为整数,所以s和c的数量必须相等,以凑成若干对。

转变为背包问题:每个式子抽象成物品,s−c 抽象成重量,而我们的目标是找到总重量为 0 时的最大价值,价值抽象成 s(或者 c),因为每一对配对的 sc 价值为 1

然后正常背包 dp,和烛光晚餐那个题一样,先把数组开大,然后自己定义一个 \(t = 114514\),每次加上一个 \(t\) 就行。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int inf = 0x3f3f3f3f;
const int N = 2e6 + 5;
const int t = 114514;

int n;
int f[2][N];
int w[N], v[N];
int cnt[3];
char s[N];

int main()
{
    scanf("%d", &n);

    memset(f, 0xcf, sizeof f);
    f[0][t] = 0;
    int l = 0, r = 0;

    for (rint i = 1; i <= n; i++)
    {
        scanf("%s", s + 1);
        int len = strlen(s + 1);
        cnt[1] = 0;
        cnt[2] = 0;
        
        for (rint j = 1; j <= len; j += 2)
        {
            if (s[j] == 'c')
                cnt[1]++;
            if (s[j] == 's')
                cnt[2]++;
        }

        w[i] = cnt[1];
        v[i] = cnt[2] - cnt[1];
        l = std::min(l, v[i] + l);
        r = std::max(r, v[i] + r);

        for (rint j = l; j <= r; j++)
        {
            f[i & 1][j + t] = std::max(f[i & 1][j + t], f[i & 1 ^ 1][j + t]);
            f[i & 1][j + t] = std::max(f[i & 1][j + t], f[i & 1 ^ 1][j - v[i] + t] + w[i]);
        }
    }

    printf("%d", f[n & 1][t]);

    return 0;
}

CF19B Assistant

01 背包的变式。

完成任务是使所有物品全部买或偷到,而 Bob 有多少扫描时间便能偷多少物品,所以扫描了某一物品,能带走的物品便是扫描时间,加上 1 (也就是你正在扫描的这个物品),在这里可以直接把 扫描时间++,那么现在扫描时间就等价与能带走的物品个数了。

将扫描时间看作体积,价格看作价值。那么题目便等价与把背包体积至少装至 \(n\) 的最小价值,就成为 01 背包板子题了。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define int long long
#define rint register int
#define endl '\n'

const int N = 1e5 + 5;
const int inf = 1e18;

int n, m;
int f[N];
int v[N], w[N];
int ans = inf;

signed main()
{
    scanf("%lld", &n);
	
    for(rint i = 1; i <= n; i++)
    {
	scanf("%lld%lld", &v[i], &w[i]);
	v[i]++;
	m = std::max(m, v[i]);
    }
	
    memset(f, 0x7f, sizeof f);
    f[0] = 0;
    m += n;
	
    for(rint i = 1; i <= n; i++)
    {
        for(rint j = m; j >= v[i]; j--)
	{
            f[j] = std::min(f[j], f[j - v[i]] + w[i]);
	}
    }
	
    for(rint i = n; i <= m; i++)
    {
	ans = std::min(f[i], ans);
    }
	
    printf("%lld", ans);
	
    return 0;
}

CF632E Thief

将题目转化为:求凑某个权值和最少要多少物品,如果最少个数比 \(k\) 要大,则这个权值和就无法凑出。

然后背包 dp 就行了。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e6 + 5;
const int inf = 0x3f3f3f3f;

int n, m, k;
int v[N];
int f[N];
int minn = inf;

int main()
{
    scanf("%d%d", &n, &k);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &v[i]);
        minn = std::min(v[i], minn);
        m = std::max(v[i], m);
    }
    m -= minn;

    for (rint i = 1; i <= n; i++)
    {
        v[i] -= minn;
    }

    memset(f, 0x3f, sizeof f);
    f[0] = 0;

    for (rint i = 1; i <= k * m; i++)
    {
        for (rint j = 1; j <= n; j++)
        {
            if (v[j] <= i)
            {
                f[i] = std::min(f[i], f[i - v[j]] + 1);
            }
        }
    }

    for (rint i = 0; i <= k * m; i++)
    {
        if (f[i] <= k)
        {
            printf("%d ", i + k * minn);
        }
    }

    return 0;
}

CF577B Modulo Sum

小学奥数:抽屉原理

把多于 \(n\) 个的苹果放到 \(n\) 个抽屉里,则至少有一个抽屉里的苹果不少于两个。

所以 \(n ≥ m\) 直接输出 YES

如果不是,\(n\) 的数据范围就会不超过 \(1000\);

这个时候就可以直接背包 dp 了;

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e3 + 5;

bool f[N][N], flag;
int n, m;
int v[N];

int main()
{
    scanf("%d%d", &n, &m);

    if (n > m)
    {
        puts("YES");
        return 0;
    }

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &v[i]);
        v[i] %= m;
    }

    for (rint i = 1; i <= n; i++)
    {
        if (flag)
        {
            continue;
        }
        f[i][v[i]] = 1;
        for (rint j = 1; j <= m; j++)
        {
            f[i][j] |= f[i - 1][j];
            f[i][(j + v[i]) % m] |= f[i - 1][j];
        }
        flag |= f[i][0];
    }

    if (flag)
    {
        puts("YES");
        return 0;
    }

    puts("NO");
    return 0;
}

Part3.区间 dp

区间 dp 常见的状态设计为 :

  • f[l][r] 表示在 \(l-r\) 区间内的答案

  • f[l][r][0/1] 表示在 \(l-r\) 区间落脚点在 左端点/右端点 内的答案

  • f[l][r][k] 表示在 \(l-r\) 区间内选择 \(k\) 个东西的答案

常见转移方程:

\[f[l][r] = min/max(f[l][r], f[l][k] + f[k + 1][r]) \]

[NOI1995] 石子合并

定义状态 f[l][r],表示 lr 合并后的最大得分。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 4e2 + 5;
const int inf = 0x3f3f3f3f;

int f[N][N];
int n;
int a[N], s[N];

int cost(int l, int r)
{
    return s[r] - s[l - 1];
}

int dp_min()
{
    int minn = inf;
    for(rint len = 1; len < n; len++)
    {
    	for(rint l = 1; l <= 2 * n - len; l++)
    	{
    	    int r = l + len;
    	    f[l][r] = inf;
    	    for(rint k = l; k < r; k++)
    	    {
    	        f[l][r] = std::min(f[l][r],f[l][k] + f[k + 1][r] + cost(l, r));
	    }
	    if(len + 1 == n)
	    {
		minn = std::min(minn, f[l][r]);
	    }
        }
    }
    return minn;
}

int dp_max()
{
    int maxx = 0;
    for(rint len = 1; len < n; len++)
    {
	for(rint l = 1;l <= 2 * n - len; l++)
	{
	    int r = l + len;
	    f[l][r] = 0;
	    for(rint k = l; k < r; k++)
	    {
		f[l][r] = std::max(f[l][r], f[l][k] + f[k + 1][r] + cost(l, r));
	    }
	    if(len + 1 == n)
	    {
		maxx = std::max(maxx, f[l][r]);
	    }
        }
    }
    return maxx;
}

int main()
{
    scanf("%d", &n);
    
    for(rint i = 1; i <= n * 2; i++)
    {
    	if(i <= n)
	{
	    scanf("%d", &a[i]);
	    a[i + n] = a[i];
        }
    	s[i] = s[i - 1] + a[i];
    }
	
    printf("%d\n%d", dp_min(), dp_max());
	
    return 0;
}

[CQOI2007] 涂色

f[l][r] 为字符串的子串 s[l]~s[r] 的最少染色次数

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e2 + 5;

char s[N];
int n;
int f[N][N];

int main()
{
    scanf("%s", s + 1);
    n = strlen(s + 1);
    memset(f, 0x3f, sizeof f);

    for (rint i = 1; i <= n; i++)
    {
        f[i][i] = 1;
    }

    for (rint len = 1; len < n; len++)
    {
        for (rint l = 1; l <= n - len; l++)
        {
            int r = l + len;
            if (s[l] == s[r])
            {
                f[l][r] = std::min(f[l + 1][r], f[l][r - 1]);
            }
            else
            {
                for (rint k = l; k < r; k++)
                {
                    f[l][r] = std::min(f[l][r], f[l][k] + f[k + 1][r]);
                }
            }
        }
    }

    printf("%d", f[1][n]);

    return 0;
}

CF245H Palindromes

f[l][r] 表示在 lr 的这个区间内的回文串数量

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 5e3 + 5;

int n, T;
int f[N][N];
char s[N];
int is[N][N];

bool cost(int l, int r)
{
    if (is[l][r] != -1)
    {
        return is[l][r]; //之前搜索过的状态无需再搜
    }
    if (l == r)
    {
        is[l][r] = 1;
        return true;
    }
    if (l + 1 == r)
    {
        if (s[l] == s[r])
        {
            is[l][r] = 1;
            return true;
        }
        else
        {
            is[l][r] = 0;
            return false;
        }
    }
    if (s[l] != s[r])
    {
        is[l][r] = 0;
        return false;
    }
    is[l][r] = cost(l + 1, r - 1);
    return is[l][r];
}

int main()
{
    memset(is, -1, sizeof is);
    scanf("%s", s + 1);
    scanf("%d", &T);
    n = strlen(s + 1);

    for (rint i = 1; i <= n; i++)
    {
        f[i][i] = 1;
    }

    for (rint i = 1; i < n; i++)
    {
        if (s[i] == s[i + 1])
        {
            f[i][i + 1] = 3;
        }
        else
        {
            f[i][i + 1] = 2;
        }
    }

    for (rint len = 2; len < n; len++)
    {
        for (rint l = 1; l <= n - len; l++)
        {
            int r = l + len;
            f[l][r] = f[l + 1][r] + f[l][r - 1] - f[l + 1][r - 1] + cost(l, r);
        }
    }

    while (T--)
    {
        int l, r;
        scanf("%d%d", &l, &r);
        printf("%d\n", f[l][r]);
    }

    return 0;
}

[IOI1998] Polygon

f[i][j] 表示 [i,j] 这一个区间内可以得到的最大得分

但是,如果直接正常写会挂掉 20pts

为什么呢,举个例子,两个极小的负数相乘是可以大于两个小正整数相乘的。

所以需要同时维护最小值

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

#define min(a, b) std::min(a, b)
#define max(a, b) std::max(a, b)

const int inf = 0x3f3f3f3f;
const int N = 1e2 + 5;

int ans = -inf;
int n;
int a[N];
int f1[N][N], f2[N][N];
char b[N];

int main()
{
    scanf("%d", &n);

    for (rint i = 1; i <= n; i++)
    {
        std::cin >> b[i] >> a[i];
        a[i + n] = a[i];
        b[i + n] = b[i];
    }

    memset(f1, 0xcf, sizeof f1);
    memset(f2, 0x3f, sizeof f2);

    for (rint i = 1; i <= 2 * n; i++)
    {
        f1[i][i] = f2[i][i] = a[i];
    }

    for (rint len = 1; len < 2 * n; len++)
    {
        for (rint l = 1; l <= 2 * n - len; l++)
        {
            int r = l + len;
            for (rint k = l; k < r; k++)
            {
                if (b[k + 1] == 't')
                {
                    f1[l][r] = max(f1[l][r], f1[l][k] + f1[k + 1][r]);
                    f2[l][r] = min(f2[l][r], f2[l][k] + f2[k + 1][r]);
                }
                if (b[k + 1] == 'x')
                {
                    f1[l][r] = max(f1[l][r], max(f1[l][k] * f1[k + 1][r], 
                                                 f2[l][k] * f2[k + 1][r]));
                    f2[l][r] = min(f2[l][r], min(f1[l][k] * f1[k + 1][r], 
                               min(f2[l][k] * f2[k + 1][r], 
                               min(f1[l][k] * f2[k + 1][r], 
                                   f2[l][k] * f1[k + 1][r]))));
                }
            }
        }
    }

    for (rint i = 1; i <= n; i++)
    {
        ans = max(ans, f1[i][i + n - 1]);
    }

    printf("%d\n", ans);

    for (rint i = 1; i <= n; i++)
    {
        if (ans == f1[i][i + n - 1])
        {
            printf("%d ", i);
        }
    }

    return 0;
}

P1220 关路灯

f[i][j] 记为当从 ij 的灯都熄灭后剩下的灯的总功率。

但还不够,继续延伸

f[i][j][0] 表示关掉 ij 的灯后,老张站在 i 端点,f[i][j][1] 表示关掉 ij 的灯后,老张站在右端点 (i 为左端点,j 为右端点)

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 5e1 + 5;

int n, c;
int f[N][N][2];
int p[N], a[N];
int s[N];

int cost(int i, int j, int l, int r)
{
    return (p[j] - p[i]) * (s[n] + s[l] - s[r - 1]);
}

int main()
{
    scanf("%d%d", &n, &c);

    memset(f, 0x3f, sizeof f);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d%d", &p[i], &a[i]);
        s[i] = s[i - 1] + a[i];
    }

    f[c][c][0] = 0;
    f[c][c][1] = 0;

    for (rint r = c; r <= n; r++)
    {
        for (rint l = r - 1; l >= 1; l--)
        {
            f[l][r][0] = std::min(f[l + 1][r][0] + cost(l, l + 1, l, r + 1), 
                                  f[l + 1][r][1] + cost(l, r, l, r + 1));
            f[l][r][1] = std::min(f[l][r - 1][1] + cost(r - 1, r, l - 1, r), 
                                  f[l][r - 1][0] + cost(l, r, l - 1, r));
        }
    }

    int ans = std::min(f[1][n][1], f[1][n][0]);

    printf("%d", ans);

    return 0;
}

[ABC163E] Infants

f[l][r] 为该区间对答案的最大贡献值。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define int long long
#define rint register int
#define endl '\n'

const int N = 2e3 + 5;

struct node
{
    int v;
    int id;
} a[N];

bool operator<(node a, node b)
{
    return a.v < b.v;
}

int n;
int f[N][N];

signed main()
{
    scanf("%lld", &n);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%lld", &a[i].v);
        a[i].id = i;
    }

    std::sort(a + 1, a + n + 1);

    for (rint i = 1; i <= n; i++)
    {
        f[i][i] = a[1].v * std::abs(a[1].id - i);
    }

    for (rint len = 1; len < n; len++)
    {
        for (rint l = 1; l <= n - len; l++)
        {
            int r = l + len;
            f[l][r] = std::max(f[l][r], f[l + 1][r] + a[r-l+1].v * std::abs(a[r-l+1].id - l));
            f[l][r] = std::max(f[l][r], f[l][r - 1] + a[r-l+1].v * std::abs(a[r-l+1].id - r));
        }
    }

    printf("%lld", f[1][n]);

    return 0;
}

[JSOI2007] 祖玛

f[i][j] 表示消除从第 i 个位置到第 j 个位置的序列需要的最少操作数

cnt[] 记录个数

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 5e3 + 5;

int n;
int a[N], cnt[N];
int idx;
int f[N][N];

int main()
{
    scanf("%d", &n);

    memset(f, 0x3f, sizeof f);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        if (i != 1 && a[i] == a[i - 1])
        {
            cnt[idx]++;
        }
        else
        {
            idx++;
            a[idx] = a[i];
            cnt[idx] = 1;
        }
    }

    for (rint i = 1; i <= n; i++)
    {
        if (cnt[i] > 1)
        {
            f[i][i] = 1;
        }
        else
        {
            f[i][i] = 2;
        }
    }

    for (rint len = 1; len < idx; len++)
    {
        for (rint l = 1; l <= idx - len; l++)
        {
            int r = l + len;
            if (a[l] == a[r])
            {
                f[l][r] = std::min(f[l][r], f[l + 1][r - 1] + (cnt[l] + cnt[r] > 2 ? 0 : 1));
            }
            for (rint k = l; k < r; k++)
            {
                f[l][r] = std::min(f[l][r], f[l][k] + f[k + 1][r]);
            }
        }
    }

    printf("%d", f[1][idx]);

    return 0;
}

S2OJ#.1201

f[l][r] 表示在该区间内最少承受多少伤害

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <math.h>

#define int long long
#define rint register int
#define endl '\n'

const int N = 4e2 + 5;
const int inf = 0x7f7f7f7f;

int n, m;
int a[N], b[N], h[N];
int f[N][N];

signed main()
{
    scanf("%lld%lld", &n, &m);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%lld%lld%lld", &a[i], &b[i], &h[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
	f[i][i] = (h[i] + m - 1) / m * (a[i] + b[i - 1] + b[i + 1]);
    }

    for (rint len = 1; len < n; len++)
    {
	for (rint l = 1; l <= n - len; l++)
	{
	    int r = l + len;
	    f[l][r] = inf;
	    for (rint k = l; k <= r; k++)
	    {
		f[l][r] = std::min(f[l][r], f[l][k - 1] + f[k + 1][r] + (h[k] + m - 1) / m * (a[k] + b[l - 1] + b[r + 1]));
	    }
        }
    }

    printf("%lld", f[1][n]);

    return 0;
}

[HAOI2016] 字符合并

f[i][j][t] 表示原串中第 ij 个数字最终合并成 t 的状态的最大分数。

需要状态压缩优化一下,后边会进行复习。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define int long long
#define rint register int
#define endl '\n'

const int N = 3e2 + 5;
const int M = 256;
const int inf = 1e18;

int n, m;
int c[N], w[N], a[N];
int f[N][N][M];
int g[2];
int ans;

signed main()
{
    scanf("%lld%lld", &n, &m);

    for (rint i = 1; i <= n; i++)
    {
	scanf("%lld", &a[i]);
    }

    for (rint i = 0; i < (1 << m); i++)
    {
	scanf("%lld%lld", &c[i], &w[i]);
    }

    memset(f, 0xcf, sizeof f);


    for (rint i = 1; i <= n; i++)
    {
	f[i][i][a[i]] = 0;
    }

    for (rint len = 1; len < n; len++)
    {
	for (rint l = 1; l + len <= n; l++)
	{
	    int r = l + len;
	    int x = len % (m - 1);

	    if (!x)
	    {
		x = m - 1;
	    }
	    for (rint k = r - 1; k >= l; k -= m - 1)
	    {
		for (rint p = 0; p < (1 << x); p++)
		{
		    f[l][r][p << 1] = std::max(f[l][r][p << 1], f[l][k][p] + f[k + 1][r][0]);
		    f[l][r][p << 1 | 1] = std::max(f[l][r][p << 1 | 1], f[l][k][p] + f[k + 1][r][1]);
		}
	    }

	    if (x == m - 1)
	    {
		g[0] = -inf;
		g[1] = -inf;
	        for (rint p = 0; p < (1 << m); p++)
	        {
		    int k = c[p];
		    g[k] = std::max(g[k], f[l][r][p] + w[p]);
		}
		f[l][r][0] = g[0];
		f[l][r][1] = g[1];
	    }
	}
    }

    for (rint i = 0; i < (1 << m); i++)
    {
	ans = std::max(ans, f[1][n][i]);
    }

    printf("%lld", ans);

    return 0;
}

Part4.数位 dp(递推)

[SCOI2009] windy 数

f[i][j] 为长度为 i 中最高位是 j 的 windy 数的个数

具体注释在代码里

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <math.h>

#define rint register int
#define endl '\n'

const int N = 1e2 + 5;

int n, m;
int f[N][N];
int a[N];

int query(int x)  //求出 <x 的 windy 数
{
    int idx = 0; //记录位数
    int ans = 0;

    while (x > 0)
    {
	a[++idx] = x % 10;
	x /= 10;
    }

    for (rint i = 1; i < idx; i++)
    // 不取等是因为一会长度为 idx 的情况另外计算
    //因为长度为 idx 的时候最高位不一定到 9
    {
        for (rint j = 1; j <= 9; j++)
	{
	    ans += f[i][j];
	}
    }

    for (rint i = 1; i < a[idx]; i++) //处理最高位 < a[idx] 的 windy 数
    {
	ans += f[idx][i];
    }

    for (rint i = idx - 1; i >= 1; i--)
    //最处理高位 = a[idx] 的 windy 数, 长度逐渐缩小,直到不合法为止
    {
	for (rint j = 0; j < a[i]; j++) //注意:此处不是最高位,可以为零
	{
	    if (std::abs(j - a[i + 1]) >= 2) //如果满足 windy 数要求
	    {
		ans += f[i][j];
	    }
        }
	if (std::abs(a[i] - a[i + 1]) < 2) //如果不合法
	{
	    break;
	}
    }

    return ans;
}

int main()
{
    for (rint i = 0; i <= 9; i++)
    {
	f[1][i] = 1;
    }

    for (rint i = 2; i <= 10; i++)
    {
	for (rint j = 0; j <= 9; j++)
	{
	    for (rint k = 0; k <= 9; k++)
	    {
		if (std::abs(j - k) >= 2)
		{
	            f[i][j] += f[i - 1][k];
		}
	    }
        }
    }

    scanf("%d%d", &n, &m);

    printf("%d", query(m + 1) - query(n));

    return 0;
}

[ZJOI2010] 数字计数

f[i][j][k] 表示第 i 位,最高位为 jk 有多少个。

首先我们可以得到一个粗略的方程:

\(f[i][j][k]=\sum\limits_{l=0}^{9}{f[i-1][l][k]}+\text{count}\)

\(count\)\(k\) 在这一位上出现的次数。

那么在这一位上出现了多少次呢?

如果 \(j > k\) \(or\) \(j < k\) ,很明显,皮蛋。

\(j = k\) 呢?

  • 比如 [10,19]1 就出现了 10 次。
  • [100][199]1 很明显出现了 100 次。

怎么来的?

\(10=10^1=10^{2-1}\)
\(100=10^2=10^{3-1}\)

\(f[i][j][j]+=10^{i-1}\)

剩下的就是数位 dp 模板了:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <math.h>

#define int unsigned long long
#define rint register int
#define endl '\n'

const int N = 1e1 + 5;

int f[N][N][N];
int b, c;
int a[N];

int query(int x, int num)
{
    int idx = 0;
    int ans = 0;

    while (x > 0)
    {
        a[++idx] = x % 10;
        x /= 10;
    }

    for (rint i = 1; i < idx; i++)
    {
        for (rint j = 1; j <= 9; j++)
        {
            ans += f[i][j][num];
        }
    }

    for (rint i = 1; i < a[idx]; i++)
    {
        ans += f[idx][i][num];
    }

    for (rint i = idx - 1; i >= 1; i--)
    {
        for (rint j = 0; j < a[i]; j++)
        {
            ans += f[i][j][num];
        }
        for (rint j = idx; j > i; j--)
        {
            if (a[j] == num)
            {
                ans += a[i] * pow(10, i - 1);
            }
        }
    }
    return ans;
}

signed main()
{
    scanf("%lld%lld", &b, &c);

    for (rint i = 0; i <= 9; i++)
    {
        f[1][i][i] = 1;
    }

    for (rint i = 2; i <= 13; i++)
    {
        for (rint j = 0; j <= 9; j++)
        {
            for (rint k = 0; k <= 9; k++)
            {
                for (rint l = 0; l <= 9; l++)
                {
                    f[i][j][l] += f[i - 1][k][l];
                }
            }
            f[i][j][j] += pow(10, i - 1);
        }
    }

    for (rint i = 0; i <= 9; i++)
    {
        printf("%lld ", query(c + 1, i) - query(b, i));
    }

    return 0;
}

AcWing.310 启示录

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define int long long
#define rint register int
#define endl '\n'

const int N = 2e1 + 1;

int n;
int f[N][4];
int g[N];
// f[i][j]: 由i个数构成的左边(高位)有j个6的数的个数
// g[i]: 由i个数构成魔鬼数的个数

signed main()
{
    f[0][0] = 1;

    for (rint i = 1; i < N; i++)
    {
        f[i][0] = 9 * (f[i - 1][0] + f[i - 1][1] + f[i - 1][2]); //左边有0个6,只要这个数不选6就行
        f[i][1] = f[i - 1][0];
        f[i][2] = f[i - 1][1];              //最左边只能选6
        g[i] = 10 * g[i - 1] + f[i - 1][2]; //由后面i - 1个数构成666,或者最左边选6再加两个6
    }

    int T;
    scanf("%lld", &T);

    while (T--)
    {
        int x;
        scanf("%lld", &x);
        int m = 3;//确定这个数的位数
        while (g[m] < x)
        {
            m++;
        }

        int k = 0;
        for (rint i = m; i >= 1; i--)
        //从左到右(从高位到低位)枚举
        // k记录左边有多少个连续的6,k = 3代表已经出现过666
        {
            for (rint j = 0; j <= 9; j++)
            {
                int cnt = g[i - 1];

                if (j == 6 || k == 3)
                {
                    //现在左边已经有k + (j == 6)个6,还差3 - k - (j == 6)个
					for (rint l = std::max(0ll, 3 - k - (j == 6)); l < 3; l++)
                    {
                        cnt += f[i - 1][l];
                    }
                }

                if (cnt >= x)
                {
                    printf("%lld", j);
                    if (k < 3)
                    {
                        if (j == 6)
                        {
                            k++;
                        }
                        else
                        {
                            k = 0;
                        }
                    }
                    break;
                }
                else
                {
                    x -= cnt;
                }
            }
        }
        puts("");
    }

    return 0;
}

Part5.状态压缩 dp

状态压缩 dp 通常用二进制数存储第二维,而第二维通常表示一种状态,比如一个序列要选合法的数字,我们就可以用第二维二进制存储,000111 就可以用来存储前三位不合法,后三位合法。

常见位运算:

~              取反
&              两个数有 0 就返回 0
|              两个数有 1 就返回 1
^              两个数相同则为 0 ,否则为 1
<<             左移
>>             右移
(n >> k) & 1   取 n 的第 k 位
n & ((1 << k) - 1)  取 n 的后 k 位
n ^ (1 << k)   将 n 的第 k 位取反
n | (1 << k)   将 n 的第 k 位改为 1
n & (~(1 << k)) 将 n 的第 k 位改为 0 

AcWing.291 蒙德里安的梦想

f[i][j] 表示 i - 1 列的方案数已经确定,从 i - 1 列伸出,并且第 i 列的状态是 j 的所有方案数

满足两个性质就可以进行转移:

  • jk 的方案截然相反
  • jk 方案中不能出现连续奇数的空位,否则会出现第 i - 1 位无法填放的问题

就可以写出代码了

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 12;
const int M = (1 << N) + 1;

int n, m;
long long f[N][M];
bool st[M];

int main()
{
    while (scanf("%d%d", &n, &m) and n != 0)
    {
        memset(f, 0, sizeof f);
        for (rint i = 0; i < 1 << n; i++)
        {
            int cnt = 0;
            st[i] = 1;
            for (rint j = 0; j < n; j++)
            {
                if (i >> j & 1)
                {
                    if (cnt & 1)
                    {
                        st[i] = 0;
                    }
                    cnt = 0;
                }
                else
                {
                    cnt++;
                }
            }
            if (cnt & 1)
            {
                st[i] = 0;
            }
        }

        f[0][0] = 1;

        for (rint i = 1; i <= m; i++)
        {
            for (rint j = 0; j < 1 << n; j++)
            {
                for (rint k = 0; k < 1 << n; k++)
                {
                    if ((j & k) == 0 && (st[j | k]))
                    {
                        f[i][j] += f[i - 1][k];
                    }
                }
            }
        }

        printf("%d\n", f[m][0]);
    }

    return 0;
}

[USACO]Corn Fields

f[i][j] 表示在前 i 行中(包括 i )在 j 个状态下的最大方案数。

g[i] 来表示第 i 行上的草地情况。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 13;
const int M = (1 << N) + 1;
const int mod = 1e8;

int n, m;
int a[N][N];
int g[M];
int f[N][M];
bool st[M];

int main()
{
    scanf("%d%d", &m, &n);

    for (rint i = 1; i <= m; i++)
    {
        for (rint j = 1; j <= n; j++)
        {
            scanf("%d", &a[i][j]);
        }
    }

    for (rint i = 1; i <= m; i++)
    {
        for (rint j = 1; j <= n; j++)
        {
            g[i] = (g[i] << 1) + a[i][j];
        }
    }

    for (rint i = 0; i < 1 << n; i++)
    {
        st[i] = ((i & (i << 1)) == 0) && ((i & (i >> 1)) == 0);
    }

    f[0][0] = 1;

    for (rint i = 1; i <= m; i++)
    {
        for (rint j = 0; j < 1 << n; j++)
        {
            for (rint k = 0; k < 1 << n; k++)
            {
                if (st[j] && ((j & g[i]) == j) && ((k & j) == 0))
                {
                    f[i][j] = (f[i][j] + f[i - 1][k]) % mod;
                }
            }
        }
    }

    long long ans = 0;

    for (rint i = 0; i < 1 << n; i++)
    {
        ans = (ans + f[m][i]) % mod;
    }

    printf("%lld", ans);

    return 0;
}

[SCOI2007] 排列

f[i][j] 表示现在的状态为 i,当前所选的数组成的数字对 d 取余后的值为 j

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 11;
const int M = (1 << N) + 1;

int a[N];
int f[M][N * 100];
char s[N];
bool st[M];

int main()
{
    int T;
    scanf("%d", &T);
    while (T--)
    {
        int d;
        memset(f, 0, sizeof f);
        scanf("%s%d", s + 1, &d);
        int n = strlen(s + 1);
        for (rint i = 1; i <= n; i++)
        {
            a[i] = s[i] - '0';
        }

        f[0][0] = 1;

        for (rint i = 0; i < (1 << n) - 1; i++)
        {
            memset(st, 0, sizeof st);
            for (rint j = 1; j <= n; j++)
            {
                if (!(i & (1 << (j - 1))) && !st[a[j]])
                {
                    st[a[j]] = 1;
                    for (rint k = 0; k < d; k++)
                    {
                        f[i | (1 << (j - 1))][(k * 10 + a[j]) % d] += f[i][k];
                    }
                }
            }
        }

        printf("%d\n", f[(1 << n) - 1][0]);
    }

    return 0;
}

[USACO]Mix Cow

f[i][j] 表示以第 i 只奶牛为结尾的状态为 j 的队伍混乱的方案数是多少

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define int long long
#define rint register int
#define endl '\n'

const int N = 17;
const int M = (1 << N) + 1;

int f[N][M];
int n, m;
int a[N];
int ans;

signed main()
{
    scanf("%lld%lld", &n, &m);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%lld", &a[i]);
        f[i][1 << (n - i)] = 1;
    }

    for (rint j = 0; j < 1 << n; j++)
    {
        for (rint i = 1; i <= n; i++)
        {
            for (rint k = 1; k <= n; k++)
            {
                if (k != i && std::abs(a[k] - a[i]) > m && (j & (1 << (n - i))))
                {
                    f[i][j] += f[k][j ^ (1 << (n - i))];
                }
            }
        }
    }

    for (rint i = 1; i <= n; i++)
    {
        ans += f[i][(1 << n) - 1];
    }

    printf("%lld", ans);

    return 0;
}

CF327E Axis Walking

很好的一道题目,借鉴了一下题解,话说 lowbit 记录最后一位这波操作是实在想不到的。

\(f_i\)​ 为选择状态为 \(i\) 的方案数,易知最终答案为 \(f_{2^n-1}\)

\[f_i = \sum_{i⊕j = 2^p}^{(0,i)}f_j \]

直接枚举所有的 \(j\) 会炸掉,拿一个东西记录最后一个位是什么,然后将一个原值等于 \(i\) 的寄存器异或掉她,然后拿这个寄存器的最后一位去异或 \(i\) 即可枚举到所有可能的 \(j\)

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int mod = 1e9 + 7;
const int N = 24;
const int M = (1 << N) + 1;

int n, k;
int a[M], f[M], b[3];

int lowbit(int x)
{
    return x & (-x);
}

signed main()
{
    scanf("%d", &n);

    for (rint i = 1; i < 1 << n; i <<= 1)
    {
        scanf("%d", &a[i]);
    }

    scanf("%d", &k);

    for (rint i = 1; i <= k; i++)
    {
        scanf("%d", &b[i]);
    }

    f[0] = 1;

    for (rint i = 1; i < 1 << n; i++)
    {
        a[i] = a[i ^ lowbit(i)] + a[lowbit(i)];
        if (a[i] == b[1] || a[i] == b[2])
        {
            continue;
        }
        for (rint j = i; j; j ^= lowbit(j))
        {
            f[i] += f[i ^ lowbit(j)];
            f[i] %= mod;
        }
    }

    printf("%d", f[(1 << n) - 1]);

    return 0;
}

SP2829 TLE

nekko 哥哥和 hs_black 都切过的冷门题目,顺着状态压缩的标签找了过来。

f[i][j] 表示前 i 位的最后 m 位状态为 j 的方案数。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int mod = 1e9;
const int N = 51;
const int M = (1 << 15) + 1;

int n, m;
int f[N][M];
int c[N];
long long ans;

void clear()
{
    ans = 0;
    memset(f, 0, sizeof f);
}

int main()
{
    int T;
    scanf("%d", &T);
    while (T--)
    {
        scanf("%d%d", &n, &m);
        for (rint i = 1; i <= n; i++)
        {
            scanf("%d", &c[i]);
        }
        clear();
        for (rint i = 0; i < 1 << m; i++)
        {
            if (i % c[1] != 0)
            {
                f[1][i] = 1;
            }
        }
        for (rint i = 2; i <= n; i++)
        {
            for (rint j = 0; j < 1 << m; j++)
            {
                f[i][j] = f[i - 1][j ^ ((1 << m) - 1)];
            }

            for (rint j = 0; j < m; j++)
            {
                for (rint k = 0; k < 1 << m; k++)
                {
                    if (!(k & (1 << j)))
                    {
                        f[i][k] = (f[i][k] + f[i][k | (1 << j)]) % mod;
                    }
                }
            }

            for (rint j = 0; j < 1 << m; j++)
            {
                if (j % c[i] == 0)
                {
                    f[i][j] = 0;
                }
            }
        }

        for (rint i = 0; i < 1 << m; i++)
        {
            ans = (ans + f[n][i]) % mod;
        }

        printf("%lld\n", ans);
    }

    return 0;
}

Part6.四边形不等式

\(w(x,y)\) 为定义在整数集合上的一个二元函数,若 \(∀a⩽b⩽c⩽d,w(a,c)+w(b,d)⩽w(a,d)+w(b,c)\),那么函数 \(w\) 满足四边形不等式。

四边形不等式的另一种定义

\(w(x,y)\) 为定义在整数集合上的一个二元函数,若 \(∀a<b,w(a,b)+w(a+1,b+1)⩽w(a+1,b)+w(a,b+1)\),那么函数 \(w\) 满足四边形不等式。

一般在考试的时候四边形不等式靠打表猜结论得出。

AcWing.2889 再探石子合并

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 5e3 + 5;
const int inf = 0x3f3f3f3f;

int n;
int a[N], s[N];
int f[N][N];
int p[N][N];

int cost(int l, int r)
{
    return s[r] - s[l - 1];
}

int main()
{
    scanf("%d", &n);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);
        s[i] = s[i - 1] + a[i];
        p[i][i] = i;
    }

    for (rint len = 1; len < n; len++)
    {
        for (rint l = 1; l <= n - len; l++)
        {
            int r = l + len;
            f[l][r] = inf;
            for (rint k = p[l][r - 1]; k <= p[l + 1][r]; k++)
            {
                int t = f[l][k] + f[k + 1][r] + cost(l, r);
                if (f[l][r] > t)
                {
                    f[l][r] = t;
                    p[l][r] = k;
                }
            }
        }
    }

    printf("%d", f[1][n]);

    return 0;
}

[IOI2000] 邮局

\(w[i][j]\) 表示一个邮局覆盖的 \([i,j]\) 村庄之间的距离。

\(f[i][j]\) 表示前 \(i\) 个村庄建了 \(j\) 个邮局的最小值

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 3e3 + 5;
const int inf = 0x3f3f3f3f;

int n, m;
int a[N], s[N];
int f[N][N];
int p[N][N];
int w[N][N];

int main()
{
    scanf("%d%d", &n, &m);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);
        s[i] = s[i - 1] + a[i];
    }

    for (rint l = 1; l <= n; l++)
    {
        w[l][l] = 0;
        for (rint r = l + 1; r <= n; r++)
        {
            w[l][r] = w[l][r - 1] + a[r] - a[(l + r) >> 1];
        }
    }

    for (rint i = 1; i <= n; i++)
    {
        p[i][i] = i;
        f[0][i] = inf;
    }

    for (rint len = 1; len <= n; len++)
    {
        for (rint l = 1; l <= n - len; l++)
        {
            int r = l + len;
            f[l][r] = inf;
            for (rint k = p[l][r - 1]; k <= p[l + 1][r]; k++)
            {
                int t = f[l - 1][k - 1] + w[k][r];
                if (f[l][r] > t)
                {
                    f[l][r] = t;
                    p[l][r] = k;
                }
            }
        }
    }

    printf("%d", f[m][n]);

    return 0;
}

4.基础数据结构

Part1.栈

AcWing.128 编译器

开两个栈,一个栈可以理解为“工具人”,只是因为光标移动帮忙存储,所以这个题目转化为:

1. I x:把x插入栈A,更新栈A的栈顶位置的前缀和,更新栈A的栈顶位置的最大前缀和
2. D:弹出栈A的栈顶元素
3. L:弹出栈A的栈顶元素并插入栈B中
4. R:弹出栈B的栈顶元素并插入栈A中,更新栈A的栈顶位置的前缀和,更新栈A的栈顶位置的最大前缀和
5. Q k:直接返回 f[k]

代码实现很简单了就:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <stack>

#define rint register int
#define endl '\n'

const int N = 1e6 + 5;
const int inf = 0x3f3f3f3f;

int t, x, s[N], f[N], now;
std::stack<int> a, b;
std::stack<int> c;

void clear()
{
    a = c;
    b = c;
}

int main()
{
    scanf("%d", &t);
    f[0] = -inf;
    s[0] = 0;
    clear();
    for (rint i = 1; i <= t; i++)
    {
        char op;
        int x;
        scanf("%s", &op);

        if (op == 'I')
        {
            scanf("%d", &x);
            a.push(x);
            s[a.size()] = s[a.size() - 1] + a.top();
            f[a.size()] = std::max(f[a.size() - 1], s[a.size()]);
        }

        if (op == 'D')
        {
            if (!a.empty())
            {
                a.pop();
            }
        }

        if (op == 'L')
        {
            if (!a.empty())
            {
                b.push(a.top());
                a.pop();
            }
        }

        if (op == 'R')
        {
            if (!b.empty())
            {
                a.push(b.top());
                b.pop();
                s[a.size()] = s[a.size() - 1] + a.top();
                f[a.size()] = std::max(f[a.size() - 1], s[a.size()]);
            }
        }

        if (op == 'Q')
        {
            scanf("%d", &x);
            printf("%d\n", f[x]);
        }
    }

    return 0;
}

【模板】单调栈

这个题其实是在教会我们单调栈的思想。

单调栈其实就是帮助我们筛选出一个单调自增的或者单调递减的序列来帮助我们更好的查询答案。

啥意思?以这个题目为例。

从最右边开始往左遍历,然后加入栈。

接着从下一个开始,如果当前元素比以栈顶元素作为下标的数小,那么弹出栈顶,知道找到比自己大的,

然后更新答案,将当前元素压入栈。

由于每次找到的都是比自己打的,很显然这个栈具有单调性,所以这个做法是正确的。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <stack>

#define rint register int
#define endl '\n'

const int N = 3e6 + 5;

int n, a[N];
int ans[N]; 
std::stack<int> s;
                 
int main()
{
    scanf("%d", &n);
    
    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);		
    }

    for (rint i = n; i >= 1; i--)
    {
        while (!s.empty() && a[s.top()] <= a[i])
        {
            s.pop(); 			
	}
		
        if(s.empty()) 
            ans[i] = 0;
	else  
	    ans[i] = s.top();
		
        s.push(i);                     
    }
    
    for (rint i = 1; i <= n; i++)
    {
        printf("%d ", ans[i]); 		
    }

    return 0;
}

SP1805 HISTOGRA

我们可以以当前的高度, 向左右两边扩展

所以这样子问题就转换成为 :

当前这个数左边第一个比它小的数,当前这个数右边第一个比它小的数

l[],r[] 存储边界下标。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <stack>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

int n;
int h[N];
int l[N], r[N];

std::stack<int> s;
std::stack<int> clear;

void get_l()
{
    s.push(0); 
    // 防止空栈,也可以在循环里加个特判,都一样
    for (rint i = 1; i <= n; i++)
    {
        while (h[s.top()] >= h[i])
        {
            s.pop();
        }
        l[i] = s.top();
        s.push(i);
    }
}

void get_r()
{
    s.push(n + 1);
    for (rint i = n; i > 0; i--)
    {
        while (h[s.top()] >= h[i])
        {
            s.pop();
        }
        r[i] = s.top();
        s.push(i);
    }
}

int main()
{
    while (~scanf("%d", &n) and n != 0)
    {
        for (rint i = 1; i <= n; i++)
        {
            scanf("%d", &h[i]);
        }

        h[0] = h[n + 1] = -1;

        get_l();
        s = clear;
        get_r();

        long long ans = 0;

        for (rint i = 1; i <= n; i++)
        {
            ans = std::max(ans, (long long)h[i] * (r[i] - l[i] - 1));
        }

        printf("%lld\n", ans);
    }

    return 0;
}

[POI2008] PLA

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <stack>

#define rint register int
#define endl '\n'

const int N = 3e5 + 5;

int n, h[N];
std::stack<int> s;
int cnt = 0;

int main()
{
    scanf("%d", &n);

    for (rint i = 1; i <= n; i++)
    {
        int x;
        scanf("%d%d", &x, &h[i]);
    }
    
    s.push(0);    //防止空栈 
    for (rint i = 1; i <= n; i++)
    {
        while (!s.empty() && h[s.top()] > h[i])
        {
            s.pop();
        }
        //如果栈顶不等于建筑高度,就必定多用一张海报
        if (h[s.top()] != h[i])
        {
            cnt++;
        }
        s.push(i);
    }

    printf("%d", cnt);

    return 0;
}

Part2.队列

[NOIP2016] 蚯蚓

开三个队列模拟优先队列;

一个队列维护切后的第一段,一个队列维护切后的第二段,另外一个队列,里面存储蚯蚓长度,长度是从高到低,排好序的长度,那么每一次将被切断的蚯蚓,肯定是这三个队列的队头,因为我们这道题目具有单调递减的性质

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>

#define rint register int
#define endl '\n'

#define int long long

const int N = 1e7 + 5;

std::queue<int> p1, p2, p3;
int n, m, q, u, v, t, a[N], data;

int cmp(int a, int b)
{
    return a > b;
}

int query_max(int t)
{
    int maxn = -1;
    int a = -1, b = -1, c = -1;
    
    if (!p1.empty()) a = p1.front() + t * q;
    if (!p2.empty()) b = p2.front() + t * q;
    if (!p3.empty()) c = p3.front() + t * q;
    
    maxn = std::max(a, std::max(b, c));
    
    if (maxn == a) p1.pop();
    else if (maxn == b) p2.pop();
    else if (maxn == c) p3.pop();
    
    return maxn;
}

signed main()
{
    std::cin >> n >> m >> q >> u >> v >> t;
    
    for(rint i = 1; i <= n; i++)
    {
	scanf("%lld", &a[i]);
    }
	
    std::sort(a + 1, a + 1 + n, cmp);
    
    for(rint i = 1; i <= n; i++)
    {
        p1.push(a[i]);		
    }

    for(rint i = 1; i <= m; i++)
    {
        int x = query_max(i - 1);
        if (!(i % t))
        {
	    printf("%lld ", x);
	}
		
        int now1 = x * u / v;
        int now2 = x - now1;
        
        p2.push(now1 - i * q);
        p3.push(now2 - i * q);
    }
    
    std::cout << endl;
    
    for(rint i = 1; i <= n + m; i++)
    {
        long long x = query_max(m); 
	if (i % t == 0)
        {
	    printf("%lld ", x);
	}
    }
    
    return 0;
}

P1440 区间最小值

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 2e6 + 5;

int n, m;
int a[N], q[N];
int hh = 1, tt = 0;

int main()
{
    scanf("%d%d", &n, &m);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        printf("%d\n", a[q[hh]]);
        
        while (i - q[hh] + 1 > m && hh <= tt)
            hh++;
        while (a[i] < a[q[tt]] && hh <= tt)
            tt--;
            
        q[++tt] = i;
    }

    return 0;
}

P1714 切蛋糕

就是求最大子序和。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>

#define rint register int
#define endl '\n'

const int N = 5e5 + 5;
const int inf = 0x3f3f3f3f;

int n, m;
int a[N];
int q[N];
long long s[N];
int hh = 0, tt = 0;

int main()
{
    scanf("%d%d", &n, &m);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);
        s[i] += s[i - 1] + a[i];
    }

    long long ans = -inf;

    for (rint i = 1; i <= n; i++)
    {
        while (i - q[hh] > m && hh <= tt)
            hh++;
        while (s[i] <= s[q[tt]] && hh <= tt)
            tt--;
	ans = std::max(ans, s[i] - s[q[hh]]);
        q[++tt] = i;
    }

    printf("%lld", ans);

    return 0;
}

Part3.二叉堆

大根堆 
std::priority_queue<int, std::vector<int> > q;
小根堆
std::priority_queue<int, std::vector<int>, std::greater<int> > q;

P1090 合并果子

易证不断取最小的两堆合并成较大的一堆是最优的。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>

#define rint register int
#define endl '\n'

std::priority_queue<int, std::vector<int>, std::greater<int>> q;
int n;

int main()
{
    scanf("%d", &n);

    for (rint i = 1; i <= n; i++)
    {
        int x;
        scanf("%d", &x);
        q.push(x);
    }

    long long ans = 0;

    while (q.size() - 1)
    {
        int x = q.top();
        q.pop();
        int y = q.top();
        q.pop();
        //取出最小的两个 

        int sig = x + y;

        q.push(sig);
        ans += sig;
    }

    printf("%lld", ans);

    return 0;
}

P1168 中位数

开一个 mid,维护中位数

然后开一个小根堆和大根堆,大根堆的数字个数比小根堆的大,那么弹出大根堆堆顶,作为新的 mid ,小根堆堆顶加入原来的 mid

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>

#define rint register int
#define endl '\n'

const int N = 1e7 + 5;

std::priority_queue<int, std::vector<int>> q1;
//大根堆, 记录比 mid 小的数字
std::priority_queue<int, std::vector<int>, std::greater<int>> q2;
//小根堆, 记录比 mid 大的数字
int n;
int a[N];

int main()
{
    scanf("%d", &n);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);
    }

    int mid = a[1];

    printf("%d\n", mid);

    for (rint i = 2; i <= n; i++)
    {
        if (a[i] > mid)
        {
            q2.push(a[i]);
        }
        else
        {
            q1.push(a[i]);
        }

        if (!(i % 2))
        {
            continue;
        }

        while (q1.size() != q2.size())
        {
            if (q1.size() > q2.size())
            {
                q2.push(mid);
                mid = q1.top();
                q1.pop();
            }
            else
            {
                q1.push(mid);
                mid = q2.top();
                q2.pop();
            }
        }

        printf("%d\n", mid);
    }

    return 0;
}

P1631 序列合并

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

std::priority_queue<int, std::vector<int>, std::greater<int>> q;
int n;
int a[N], b[N];

int main()
{
    scanf("%d", &n);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);
    }

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &b[i]);
    }
    //题目中已经说了是递增的序列 

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = 1; j <= n; j++)
        {
            int sig = a[i] + b[j];
            //有i ×j个和,所以当 i ×j >= n 时 break 就行
            q.push(sig);
            if (i * j >= n)
            {
                break;
            }
        }
    }

    for (rint i = 1; i <= n; i++)
    {
        printf("%d ", q.top());
        q.pop();
    }

    return 0;
}

5.字符串专题

Part1.Trie 字典树

字典树是啥,这个东西简单来说就是一个词典,然后快速帮你查出你所统计过的由 26 个字母组成的单词(稍微改一改也可以实现别的功能)

先看一个模板题练练手

AcWing.835 Trie 字符串

名副其实的模板了属于是,这个是 Trie 的起步题目了;

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

int ch[N][26], cnt[N];
int tot;
char s[N];

void insert(char s[])
{
    int p = 0;
    for (rint i = 0; s[i]; i++)
    {
        int u = s[i] - 'a';
        if (!ch[p][u])
        {
            ch[p][u] = ++tot;
        }
        p = ch[p][u];
    }
    cnt[p]++;
}

int query(char s[])
{
    int p = 0;
    for (rint i = 0; s[i]; i++)
    {
        int u = s[i] - 'a';
        if (!ch[p][u])
        {
            return 0;
        }
        p = ch[p][u];
    }
    return cnt[p];
}

int main()
{
    int T;
    scanf("%d", &T);

    while (T--)
    {
        char op;
        std::cin >> op;
        scanf("%s", s);

        if (op == 'I')
        {
            insert(s);
        }
        else
        {
            printf("%d\n", query(s));
        }
    }

    return 0;
}

[TJOI2010] 阅读理解

非常裸的字典树,但是你会发现第 11 个点过不去。

数组开大了就炸空间了,怎么办!!!

bitset 代替 bool 数组就可以了

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <bitset>

#define rint int
#define endl '\n'

const int N = 3e5 + 5;
const int M = 5e5 + 5;

int n, m;
char s[N];
int tot;
int ch[N][26];
std::bitset<1005> v[M];

void insert(char *s, int x)
{
    int p = 0;
    for (rint i = 0; s[i]; i++)
    {
        int u = s[i] - 'a';
        if (!ch[p][u])
        {
            ch[p][u] = ++tot;
        }
        p = ch[p][u];
    }
    v[p][x] = 1;
}

void query_print(char *s)
{
    int p = 0;
    for (rint i = 0; s[i]; i++)
    {
        int u = s[i] - 'a';
        if (!ch[p][u])
        {
            printf(" ");
            puts("");
            return;
        }
        p = ch[p][u];
    }
    for (rint i = 1; i <= n; i++)
    {
        if (v[p][i] == 1)
        {
            printf("%d ", i);
        }
    }
    puts("");
}

int main()
{
    scanf("%d", &n);

    for (rint i = 1; i <= n; i++)
    {
        int x;
        scanf("%d", &x);
        for (int j = 1; j <= x; j++)
        {
            scanf("%s", s);
            insert(s, i);
        }
    }

    scanf("%d", &m);

    for (rint i = 1; i <= m; i++)
    {
        scanf("%s", s);
        query_print(s);
    }

    return 0;
}

[USACO] Secret

sum[i] 表示有多少个串经过节点 icnt[i] 表示有多少个串以节点 i 为结尾。

然后沿着 \(Trie\) 扫一遍,ans += cnt[i]

最后 ans − cnt[p] + sum[p] 就是答案。

因为如果这个串是其他串的前缀,那么其他的串一定经过这个串的结尾,总数量为 sum[i]

如果其他串是这个串的前缀,那么这个串一定经过其他串的结尾标记,所以把沿路的结尾标记加起来。

但是这样的话相匹配的串算了两次,所以要减去一个。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 5e5 + 5;

int n, m, len;
int ch[N][2];
int tot;
int cnt[N], sum[N];
int s[N];

void insert(int s[])
{
    int p = 0;
    for (rint i = 0; i < len; i++)
    {
        int u = s[i];
        if (ch[p][u] == -1)
        {
            ch[p][u] = ++tot;
        }
        p = ch[p][u];
        sum[p]++;
    }
    cnt[p]++;
}

int query(int s[])
{
    int p = 0;
    int ans = 0;
    for (rint i = 0; i < len; i++)
    {
        int u = s[i];
        if (ch[p][u] == -1)
        {
            return ans;
        }
        p = ch[p][u];
        ans += cnt[p];
    }
    return ans - cnt[p] + sum[p];
}

int main()
{
    scanf("%d%d", &m, &n);
    memset(ch, -1, sizeof ch);

    for (rint i = 1; i <= m; i++)
    {
        scanf("%d", &len);
        for (rint j = 0; j < len; j++)
        {
            scanf("%d", &s[j]);
        }
        insert(s);
    }

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &len);
        for (rint j = 0; j < len; j++)
        {
            scanf("%d", &s[j]);
        }

        printf("%d\n", query(s));
    }

    return 0;
}

UVA1401 RemWord

需要字典树 + dp

\(f[i]\) 表示 \((i−len)\) 的串有多少种表示方法

\(f[i]=∑(f[j+1])\),且要满足 \(s[i..j+1]\) 可以由多个字典拼成,这个可以用字典树实现。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <bitset>

#define rint register int
#define endl '\n'

const int mod = 20071027;
const int N = 3e5 + 5;
const int M = 1e6 + 5;

char s[N];
char c[M];
int n, tot;
int ch[M][26];
int f[M];
int len;
std::bitset<N> v;
std::bitset<N> empty__;

void clear()
{
    tot = 0;
    v = empty__;
    memset(ch, 0, sizeof ch);
    memset(f, 0, sizeof f);
}

void insert(char s[])
{
    int p = 0;
    for (rint i = 0; s[i]; i++)
    {
        int u = s[i] - 'a';
        if (!ch[p][u])
        {
            ch[p][u] = ++tot;
        }
        p = ch[p][u];
    }
    v[p] = 1;
}

void query(char s[])
{
    int len = strlen(s);
    for (rint i = len - 1; i >= 0; i--)
    {
        int p = 0;
        for (rint j = i; j < len; j++)
        {
            int u = s[j] - 'a';
            if (!ch[p][u])
            {
                break;
            }
            p = ch[p][u];
            if (v[p])
            {
                f[i] = (f[j + 1] + f[i] + mod) % mod;
            }
        }
    }
}

int main()
{
    int idx = 0;

    while (~scanf("%s", c))
    {
        clear();
        scanf("%d", &n);

        for (rint i = 1; i <= n; i++)
        {
            scanf("%s", s);
            insert(s);
        }

        f[strlen(c)] = 1;
        query(c);

        printf("Case %d: %d\n", ++idx, f[0]);
    }

    return 0;
}

Part2.01 Trie

01 Trie 一般处在与状态压缩或者二进制有关的题目中。

以一道经典的 01 Trie 例题为例子。

P4551 最长异或路径

先考虑这个题的简单版,在给定的 \(n\) 个整数中选出两个进行异或运算,得到的结果最大是多少

如果数据范围很小的话,\(O(n^2)\) 就可以。

for (rint i = 1; i <= n; i++)
{
    for (rint j = 1; j < i; j++)
    {
        ans = max(ans, a[i] ^ a[j]);
    }
}

考虑如何优化里面那一层循环。

将两个数字异或,肯定是想让两个数(二进制)对应的位数长的不一样才好,一个是 1,一个是 0,显然是最佳的。

这个时候就可以建一颗 01 Trie 来解决。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define int long long
#define rint register int
#define endl '\n'

const int N = 1e6 + 5;

int h[N], e[N], ne[N], w[N], idx;
int d[N];
int ch[N][2];
int tot;
int n;

void add(int a, int b, int c)
{
    e[++idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx;
}

void dfs(int x, int fa)
{
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        int z = w[i];
        if (y != fa)
        {
            d[y] = d[x] ^ z;
            dfs(y, x);
        }
    }
}

void insert(int val, int p)
{
    for (rint i = (1 << 30); i; i >>= 1)
    {
        bool u = val & i;
        if (!ch[p][u])
        {
            ch[p][u] = ++tot;
        }
        p = ch[p][u];
    }
}

int query(int val, int p)
{
    int ans = 0;
    for (rint i = (1 << 30); i; i >>= 1)
    {
        bool u = val & i;
        if (ch[p][!u])
        {
            ans += i;
            p = ch[p][!u];
            continue;
        }
        p = ch[p][u];
    }
    return ans;
}

signed main()
{
    scanf("%lld", &n);

    for (rint i = 1; i < n; i++)
    {
        int a, b, c;
        scanf("%lld%lld%lld", &a, &b, &c);
        add(a, b, c);
        add(b, a, c);
    }

    dfs(1, 0);

    for (rint i = 1; i <= n; i++)
    {
        insert(d[i], 0);
    }

    int res = 0;

    for (rint i = 1; i <= n; i++)
    {
        res = std::max(res, query(d[i], 0));
    }

    printf("%lld", res);

    return 0;
}

Part3.KMP

KMP 是一种将字符串进行匹配的算法,可以求出两个字符串的公共部分的起止下标和最长 border 长度

P3375 【模板】KMP

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e6 + 5;

char p[N];
char s[N];
int ne[N];

int main()
{
    scanf("%s%s", s + 1, p + 1);

    int n = strlen(s + 1);
    int m = strlen(p + 1);

    for (rint i = 2, j = 0; i <= m; i++)
    {
        while (j > 0 && p[i] != p[j + 1]) j = ne[j];
        if (p[i] == p[j + 1]) j++;
        ne[i] = j;
    }

    for (rint i = 1, j = 0; i <= n; i++)
    {
        while (j > 0 && s[i] != p[j + 1]) j = ne[j];
        if (s[i] == p[j + 1]) j++;
        if (j == m)
        {
            printf("%d\n", i - m + 1);
            j = ne[j];
        }
    }

    for (rint i = 1; i <= m; i++)
    {
        printf("%d ", ne[i]);
    }

    return 0;
}

[Cnoi2021] 符文破译

题目说的有点难懂,简化一下。

给定字符串 \(\texttt{S}\)\(\texttt{T}\),要求将 \(\texttt{S}\) 划分成最少的段,且每段是 \(\texttt{T}\) 的前缀。

假如 \(S\) 串可以由 \(T\) 串的前缀组成,那么一定存在等式 \(S[1,i] = S[1,j] + T[1,k]\) \((1≤j≤i)\)

很显然,\(T[1,k] = S[1,i] - S[1,j] = S[j + 1,i]\)

\(S[1,i]\) 的后缀一定是串 \(T\) 的某个前缀。

考虑 dp,设 \(f[i]\) 表示串 \(S[1,i]\) 最少能被分成多少个串 \(T\) 的前缀。

\(f[i]=min(f[j]+1),S[j+1,i]=T[1,i−j]\)

然而怎么判断 \(S[j+1,i]=T[1,i−j]\),发现可以使用 KMP,那么这个题就切掉了。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int inf = 0x3f3f3f3f;
const int N = 1e7 + 5;

char t[N], s[N];
int n, m;
int ne[N];
int f[N];

void get_next(char p[], int n)
{
    for (rint i = 2, j = 0; i <= n; i++)
    {
        while (j > 0 && p[i] != p[j + 1]) j = ne[j];
        if (p[i] == p[j + 1]) j++;
        ne[i] = j;
    }
}

void KMP(char p[], char s[], int m)
{
    for (rint i = 1, j = 0; i <= m; i++)
    {
        while (j > 0 && s[i] != p[j + 1]) j = ne[j];
        if (s[i] == p[j + 1]) j++;
        f[i] = std::min(f[i], f[i - j] + 1);
    }
}

int main()
{
    scanf("%d%d%s%s", &n, &m, t + 1, s + 1);
    memset(f, 0x3f, sizeof f);
    f[0] = 0;
    get_next(t, n);
    KMP(t, s, m);
    if (f[m] == inf)
    {
        puts("Fake");
        return 0;
    }
    printf("%d", f[m]);
    return 0;
}

[POI2006] OKR

首先要想想,这个题的周期到底是啥,

是某段后缀与某段前缀相等且没有重叠部分。

对于每个前缀 \(i\),定义 \(j=i\),然后不断去更新 \(j=next[j]\),最小的 \(j\) 就是答案,此时 \(ans+=i-j\)

中间需要加一个类似于记忆化的玩意儿,因为这个挂了 30pts

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int inf = 0x3f3f3f3f;
const int N = 1e7 + 5;

char s[N];
int n;
int ne[N];
int f[N];
long long ans;

void get_next(char p[], int n)
{
    for (rint i = 2, j = 0; i <= n; i++)
    {
        while (j > 0 && p[i] != p[j + 1]) j = ne[j];
        if (p[i] == p[j + 1]) j++;
        ne[i] = j;
    }
}

int main()
{
    scanf("%d%s", &n, s + 1);
    get_next(s, n);
    for (rint i = 2, j = 2; i <= n; i++, j = i)
    {
        while (ne[j]) j = ne[j];
        if (ne[i]) ne[i] = j;   //记忆化
        ans += i - j;
    }
    printf("%lld", ans);
    return 0;
}

UVA1328 Period

算法竞赛进阶指南里 KMP 那道例题的原型。

引理:

\(S[1 ,next[i]] = S[i-next[i]+1 , i]\) 一定相等且最大,且循环节长度是 \(i / (i - next[i])\)

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int inf = 0x3f3f3f3f;
const int N = 1e7 + 5;

char s[N];
int n;
int ne[N];
int f[N], h[N];
int cnt[N];

void get_next(char p[], int n)
{
    for (rint i = 2, j = 0; i <= n; i++)
    {
        while (j > 0 && p[i] != p[j + 1]) j = ne[j];
        if (p[i] == p[j + 1]) j++;
        ne[i] = j;
    }
}

void print()
{
    for (rint i = 2; i <= n; i++)
    {
        if (!(i % (i - ne[i])) && i > (i - ne[i]))
        {
            printf("%d %d\n", i, i / (i - ne[i]));
        }
    }
}

int main()
{
    int idx = 0;
    while (~scanf("%d", &n) and n != 0)
    {
        scanf("%s", s + 1);
        printf("Test case #%d\n", ++idx);
        get_next(s, n);
        print();
        puts("");
    }
    return 0;
}

[POI2005] SZA

PS:此题贺的 i207M 学长的题解。

\(f[i]\) 表示前缀 \(i\) 的答案。

\(f[i]\) 只有 \(2\) 种取值:\(i,f[next[i]]\),因为想覆盖 \(i\) 至少覆盖 \(next[i]\)

当答案为 \(f[next[i]]\) 的时候,因为前缀 \(i\) 的最后几个字符为 \(next[i]\),所以我们最多在后面接上 \(next[i]\) 这么长的字符串,也就是充要条件为存在一个 \(j\),使得 \(f[j]=f[next[i]],i−next[i]≤j\),可以用桶实现。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int inf = 0x3f3f3f3f;
const int N = 1e7 + 5;

char s[N];
int n;
int ne[N];
int f[N], h[N];
int cnt[N];

void get_next(char p[], int n)
{
    for (rint i = 2, j = 0; i <= n; i++)
    {
        while (j > 0 && p[i] != p[j + 1]) j = ne[j];
        if (p[i] == p[j + 1]) j++;
        ne[i] = j;
    }
}

int main()
{
    scanf("%s", s + 1);
    n = strlen(s + 1);
    get_next(s, n);
    f[1] = 1;
    for (rint i = 1; i <= n; i++)
    {
        f[i] = i;
        if (cnt[f[ne[i]]] >= i - ne[i])
        {
            f[i] = f[ne[i]];
        }
        cnt[f[i]] = i;
    }
    printf("%d", f[n]);
    return 0;
}

Part4.Hashmap<string, int>

Hash 就是可以理解为给每一个字符串打上一个数值,然后进行解决。

P3370 【模板】哈希

先考虑 STL 写法,开一个 map,然后记录,先定义 ans = n ,如果这个字符重复出现过了,ans 自减。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <map>

#define rint register int
#define endl '\n'

std::map<std::string, bool> m;
//定义一个以string类型为下标的bool数组
std::string s;
int n, ans;

int main()
{
    scanf("%d", &n);
    ans = n;

    for (rint i = 1; i <= n; i++)
    {
        std::cin >> s;
        if (m[s])
        {
            ans--;
        }
        else
        {
            m[s] = true;
        }
    }
    printf("%d", ans);

    return 0;
}

接着是正规 Hash 做法:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int base = 131;
const long long mod = 11451419198ll;
const int N = 1e6 + 5;

int n;
unsigned long long a[N];
char s[N];

unsigned long long Hash(char s[])
{
    int len = strlen(s + 1);
    unsigned long long ans = 0;
    for (rint i = 1; i <= len; i++)
    {
        ans = (ans * base + (unsigned long long)(s[i])) % mod;
    }
    return ans;
}

int main()
{
    scanf("%d", &n);
    for (rint i = 1; i <= n; i++)
    {
        scanf("%s", s + 1);
        a[i] = Hash(s);
    }
    std::sort(a + 1, a + n + 1);
    long long ans = 1;
    for (rint i = 1; i < n; i++)
    {
        if (a[i] != a[i + 1])
        {
            ans++;
        }
    }
    printf("%lld", ans);
    return 0;
}

P1381 单词背诵

此题用 Hash 和 map,均可,由于主程序处理较为麻烦,所以选择了 map

第一问很好解决,重点在于第二问。

第二问可以用类似于单调队列的东西维护一下。

\(1...m\) 枚举作为 \(tt\) 就可以了。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <map>

#define rint register int
#define endl '\n'

using std::map;
using std::string;

const int N = 1e5 + 5;

map<string, int> sum;
map<string, bool> v;

int ans1, ans2;
int n, m;
int hh = 1;
string s[N];

int main()
{
    scanf("%d", &n);
    for (rint i = 1; i <= n; i++)
    {
        string a;
        std::cin >> a;
        v[a] = 1;
    }
    scanf("%d", &m);
    for (rint i = 1; i <= m; i++)
    {
        std::cin >> s[i];
        if (v[s[i]])
        {
            sum[s[i]]++;
        }
        if (sum[s[i]] == 1)
        {
            ans1++;
            ans2 = i - hh + 1;
        }
        while (hh <= i)
        {
            if (!v[s[hh]])
            {
                hh++;
                continue;
            }
            if (sum[s[hh]] >= 2)
            {
                sum[s[hh]]--;
                hh++;
                continue;
            }
            break;
        }
        ans2 = std::min(ans2, i - hh + 1);
    }

    printf("%d\n%d", ans1, ans2);

    return 0;
}

[JLOI2011] 不重复数字

懒人在用 map。。。。

虽然这个题加强数据了,但是只要不用 cin 还是可以过去的

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <map>

#define rint register int
#define endl '\n'

using std::map;

const int N = 1e5 + 5;

int T, n;
map<int, bool> v;

int main()
{
    scanf("%d", &T);
    while (T--)
    {
        scanf("%d", &n);
        v.clear();
        for (rint i = 1; i <= n; i++)
        {
            int x;
            scanf("%d", &x);
            if (!v[x])
            {
                printf("%d ", x);
                v[x] = true;
            }
        }
        puts("");
    }
    return 0;
}

[CTSC2014] 企鹅 QQ

给定 n 个字符串 问两个字符串只差一个字符的字符串对的数量

分别按照从前往后的顺序和从后往前的顺序生成两个 hash 数组

然后考虑枚举哪一位是不相同的,只需要将 hash1[j − 1]hash[j + 1] 的值加起来就可以的到删除掉以为之后的 hash 值。但是要开二维的,第一维记录字符串编号。

注意 base,以后用 2333 就可以了,用别的什么 114514,质数什么的都过不去的。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 3e4 + 5;
const int M = 2e2 + 30;
const int base1 = 2333;
const int base2 = 23333;
char a[M];
unsigned long long h1[N][M];
unsigned long long h2[N][M];
int n, l, s, cnt;
unsigned long long ans[N];

void prepare(int x)
{
    for (rint i = 1; i <= l; i++)
    {
        h1[x][i] = h1[x][i - 1] * base1 + a[i];
    }
    for (rint i = l; i >= 1; i--)
    {
        h2[x][i] = h2[x][i + 1] * base2 + a[i];
    }
}

unsigned long long Hash(int i, int j)
{
    return h1[j][i - 1] * (base1 + 666) + h2[j][i + 1] * (base2 + 666);
}

int main()
{
    scanf("%d%d%d", &n, &l, &s);
    for (rint i = 1; i <= n; i++)
    {
        scanf("%s", a + 1);
        prepare(i);
    }
    for (rint i = 1; i <= l; i++)
    {
        for (rint j = 1; j <= n; j++)
        {
            ans[j] = Hash(i, j);
        }
        std::sort(ans + 1, ans + 1 + n);
        int num = 1;
        for (rint j = 2; j <= n; j++)
        {
            if (ans[j] == ans[j - 1])
            {
                cnt += num;
                num++;
            }
            else
            {
                num = 1;
            }
        }
    }
    printf("%d", cnt);
    return 0;
}

[POI2006] PAL

\(n\) 个回文串,任选两个回文串收尾相接组成 \(n^2\) 个字符串(可以自己于自己组合)。求这 \(n^2\) 个字符串中有几个是回文串

一个结论:当且仅当 \(a ,b\) 两个回文串的最小循环节一样时, \(ab\) 是回文串。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <set>
#include <math.h>

#define rint register int
#define endl '\n'

#define minn *s.lower_bound(0)

typedef unsigned long long ull;

const int base = 233;
const int N = 2e6 + 5;

ull h[N], p[N];
int n, m;
long long ans;
char a[N];
std::multiset<ull> s;

ull calc(int l, int r)
{
    return h[r] - h[l - 1] * p[r - l + 1];
}

bool check(int x)
{
    ull tmp = calc(1, x);
    for (rint i = 1; i <= m; i += x)
    {
        if (calc(i, i + x - 1) != tmp)
        {
            return 0;
        }
    }
    s.insert(tmp);
    return 1;
}

int main()
{
    p[0] = 1;
    scanf("%d", &n);
    for (rint i = 1; i <= 2000000; i++)
    {
        p[i] = p[i - 1] * base;
    }
    for (rint j = 1; j <= n; j++)
    {
        scanf("%d%s", &m, a + 1);
        for (rint i = 1; i <= m; i++)
        {
            h[i] = h[i - 1] * base + a[i];
        }
        for (rint i = 1; i <= m; i++)
        {
            if (m % i)
            {
                continue;
            }
            if (check(i))
            {
                break;
            }
        }
    }

    while (!s.empty())
    {
        long long cnt = s.count(minn);
        ans += cnt * cnt;
        s.erase(minn);
    }

    printf("%lld", ans);

    return 0;
}

[POI2012] OKR

只需要找最小的 \(len\) 使得 \(hash(l+len,r)=hash(l,r−len)\)

循环节的长度的循环次数都一定是总长的约数,总长除掉循环次数。这里就需要把 \(len\) 分解质因数。我用的是埃氏筛存储最小质因子。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <set>
#include <math.h>

#define rint register int
#define endl '\n'

typedef unsigned long long ull;

const int N = 5e5 + 5;
const int base = 233;

int n, m;
int len;
long long ans, g[N];
// g[] 记录最小质因数
ull h[N], p[N];
char s[N];

void get_prime()
{
    for (rint i = 1; i <= n; i++)
    {
        g[i] = i;
        if (!(g[i] % 2))
        {
            g[i] = 2;
        }
    }

    for (rint i = 3; i < sqrt(n); i += 2)
        for (rint j = i * 2; j <= n; j += i)
            if (g[j] > i)
                g[j] = i;
}

ull calc(int l, int r)
{
    return h[r] - h[l - 1] * p[r - l + 1];
}

int main()
{
    scanf("%d%s", &n, s + 1);
    get_prime();
    p[0] = 1;
    for (rint i = 1; i <= n; i++)
    {
        h[i] = h[i - 1] * base + s[i];
        p[i] = p[i - 1] * base;
    }

    scanf("%d", &m);

    for (rint i = 1; i <= m; i++)
    {
        int l, r;
        scanf("%d%d", &l, &r);
        ans = len = r - l + 1;

        while (len > 1)
        {
            if (calc(l + ans / g[len], r) == calc(l, r - ans / g[len]))
            {
                ans /= g[len];
            }
            len /= g[len];
        }
        printf("%lld\n", ans);
    }

    return 0;
}

6.最小生成树

根据经验来看,大部分的题目都可以使用 Kruskal 进行解决,如果这个题可以用堆优化 Prim 解决,那估计大概率 Kruskal 也能过,所以题目均使用 Kruskal 解决.

【模板】最小生成树

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 5e3 + 5;
const int M = 4e5 + 5;
const int inf = 0x3f3f3f3f;

struct rec
{
    int x, y, z;
} edge[M];

int n, m;
int fa[N];
int ans;

int get(int x)
{
    return x == fa[x] ? x : fa[x] = get(fa[x]);
}

bool operator<(rec a, rec b)
{
    return a.z < b.z;
}

int kruskal()
{
    std::sort(edge + 1, edge + m + 1);
    int idx = 0;
    for (rint i = 1; i <= n; i++)
    {
        fa[i] = i;
    }
    for (rint i = 1; i <= m; i++)
    {
        int x = get(edge[i].x);
        int y = get(edge[i].y);
        if (x == y)
        {
            continue;
        }
        fa[x] = y;
        ans += edge[i].z;
        idx++;
    }
    if (idx < n - 1)
    {
        return inf;
    }
    return ans;
}

int main()
{
    scanf("%d%d", &n, &m);

    for (rint i = 1; i <= m; i++)
    {
        scanf("%d%d%d", &edge[i].x, &edge[i].y, &edge[i].z);
    }

    ans = kruskal();

    if (ans != inf)
    {
        printf("%d", ans);
        return 0;
    }

    puts("orz");

    return 0;
}

P1195 口袋的天空

非常板子的一道题目。

他和正常的 Kruscal 唯一的区别就是要 check 一下是否达到了界限。如果 idx + k = n 了,结束操作。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e3 + 5;
const int M = 2e4 + 5;
const int inf = 0x3f3f3f3f;

struct rec
{
    int x, y, z;
} edge[M];

int n, m, k;
int fa[N];
int ans;

int get(int x)
{
    return x == fa[x] ? x : fa[x] = get(fa[x]);
}

bool operator<(rec a, rec b)
{
    return a.z < b.z;
}

int kruskal()
{
    std::sort(edge + 1, edge + m + 1);
    int idx = 0;
    for (rint i = 1; i <= n; i++)
    {
        fa[i] = i;
    }
    for (rint i = 1; i <= m; i++)
    {
        int x = get(edge[i].x);
        int y = get(edge[i].y);
        if (x == y)
        {
            continue;
        }
        fa[x] = y;
        ans += edge[i].z;
        idx++;
        if (idx == n - k)
        {
            return ans;
        }
    }
    return -1;
}

int main()
{
    scanf("%d%d%d", &n, &m, &k);

    for (rint i = 1; i <= m; i++)
    {
        scanf("%d%d%d", &edge[i].x, &edge[i].y, &edge[i].z);
    }

    ans = kruskal();

    if (ans != -1)
    {
        printf("%d", ans);
        return 0;
    }

    puts("No Answer");

    return 0;
}

[USACO] Watering

注意边的大小设置,开到极限就挺好。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 2e3 + 5;
const int M = 2e7 + 5;
const int inf = 0x7f7f7f7f;

struct rec
{
    int x, y, z;
} edge[M];

int n, m, cost;
int fa[N];
int ans;
int px[M], py[M];
int id = 0;

int get(int x)
{
    return x == fa[x] ? x : fa[x] = get(fa[x]);
}

bool operator<(rec a, rec b)
{
    return a.z < b.z;
}

int kruskal()
{
    std::sort(edge + 1, edge + m + 1);
    int idx = 0;
    for (rint i = 1; i <= n; i++)
    {
        fa[i] = i;
    }
    for (rint i = 1; i <= m; i++)
    {
        int x = get(edge[i].x);
        int y = get(edge[i].y);
        if (x == y)
        {
            continue;
        }
        fa[x] = y;
        ans += edge[i].z;
        idx++;
    }
    if (idx < n - 1)
    {
        return inf;
    }
    return ans;
}

int main()
{
    scanf("%d%d", &n, &cost);

    for (rint i = 1; i <= n; i++)
    {
        scanf("%d%d", &px[i], &py[i]);
        for (rint j = 1; j < i; j++)
        {
            int d = (px[i] - px[j]) * (px[i] - px[j]) + (py[i] - py[j]) * (py[i] - py[j]);
            if (d >= cost)
            {
                edge[++m].x = i;
                edge[m].y = j;
                edge[m].z = d;
            }
        }
    }

    ans = kruskal();

    if (ans != inf)
    {
        printf("%d", ans);
        return 0;
    }

    puts("-1");

    return 0;
}

AcWing.346 走廊泼水节

s[i] 存储连通块 i 的元素个数。

\(A\) 与点 \(B\) 之间的距离,必须要大于之前的边权,否则就会破坏之前的最小生成树。所以两个点距离最小为 \(edge[i].w + 1\)

如果 \(s[A]\)\(s[B]\) 所有点相连,那么就会产生 \(s[A] * s[B] - 1\) 条边。

\(∴(w+1)×(s[A] * s[B] - 1)\) 为两个连通块成为完全图的最小代价

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'
#define int long long

const int N = 1e4 + 5;
const int M = 4e5 + 5;
const int inf = 0x3f3f3f3f;

struct rec
{
    int x, y, z;
} edge[M];

int n;
int fa[N];
int s[N];
int ans;

int get(int x)
{
    return x == fa[x] ? x : fa[x] = get(fa[x]);
}

bool operator<(rec a, rec b)
{
    return a.z < b.z;
}

int kruscal()
{
    std::sort(edge + 1, edge + n);
    for (rint i = 1; i <= n; i++)
    {
        fa[i] = i;
        s[i] = 1;
    }
    for (rint i = 1; i < n; i++)
    {
        int x = get(edge[i].x);
        int y = get(edge[i].y);
        if (x == y)
        {
            continue;
        }
        fa[x] = y;
        ans += (s[x] * s[y] - 1) * (edge[i].z + 1);
        s[y] += s[x];
    }
    return ans;
}

signed main()
{
    int T;
    scanf("%lld", &T);
    while (T--)
    {
        scanf("%lld", &n);

        for (rint i = 1; i < n; i++)
        {
            scanf("%lld%lld%lld", &edge[i].x, &edge[i].y, &edge[i].z);
        }

        ans = 0;

        printf("%lld\n", kruscal());
    }

    return 0;
}

7.拓扑排序

就是每次删掉一个入度为零的点形成的序列。

AcWing.848 有向图拓扑

板子题训练。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

int h[N], ne[N], e[N], idx, in[N], n, m;
std::queue<int> q;
int ans[N];
int cnt = 0;

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

bool toposort()
{
    for (rint i = 1; i <= n; i++)
    {
        if (!in[i])
        {
            q.push(i);
        }
    }

    while (!q.empty())
    {
        int x = q.front();
        q.pop();

        ans[++cnt] = x;

        for (int i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            in[y]--;

            if (in[y] == 0)
            {
                q.push(y);
            }
        }
    }

    if (cnt == n)
    {
        return true;
    }

    return false;
}

int main()
{
    scanf("%d%d", &n, &m);

    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        in[b]++;
    }

    if (!toposort())
    {
        puts("-1");
        return 0;
    }

    for (rint i = 1; i <= cnt; i++)
    {
        printf("%d ", ans[i]);
    }

    return 0;
}

P4017 最大食物链计数

求从弱的动物到最强的动物有几条路可走。

d[x] += d[比 x 弱的]

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>

#define rint register int
#define endl '\n'

const int N = 1e6 + 5;
const int mod = 80112002;

int h[N], ne[N], e[N], idx, in[N], n, m;
int d[N];
std::queue<int> q;
int ans[N];
int bd[N];
int cnt = 0;

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void toposort()
{
    for (rint i = 1; i <= n; i++)
    {
        if (!in[i])
        {
            q.push(i);
            d[i]++;
        }
    }

    while (!q.empty())
    {
        int x = q.front();
        q.pop();

        ans[++cnt] = x;

        for (int i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            d[y] = (d[y] + d[x]) % mod;
            in[y]--;

            if (in[y] == 0)
            {
                q.push(y);
            }
        }
    }
}

int main()
{
    scanf("%d%d", &n, &m);

    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        in[b]++;
        bd[a]++; //判断有没有能吃他
    }

    toposort();
    int res = 0;

    for (rint i = 1; i <= n; i++)
    {
        if (bd[i] == 0)
        {
            res = (res + d[i]) % mod;
        }
    }

    printf("%d", res);

    return 0;
}

P1137 旅行计划

跟上面那个题差不多,加个 d[] 数组统计答案。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>

#define rint register int
#define endl '\n'

const int N = 2e5 + 5;

int h[N], ne[N], e[N], idx, in[N], n, m;
std::queue<int> q;
int ans[N];
int d[N];

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void toposort()
{
    for (rint i = 1; i <= n; i++)
    {
        d[i] = 1;
        if (!in[i])
        {
            q.push(i);
        }
    }
    int cnt = 0;
    while (!q.empty())
    {
        int x = q.front();
        q.pop();
        ans[++cnt] = x;
        for (int i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            d[y] = d[x] + 1;
            in[y]--;
            if (in[y] == 0)
            {
                q.push(y);
            }
        }
    }
}

int main()
{
    scanf("%d%d", &n, &m);

    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        in[b]++;
    }

    toposort();

    for (rint i = 1; i <= n; i++)
    {
        printf("%d\n", d[i]);
    }

    return 0;
}

AcWing.164 可达性统计

开个 f 数组统计答案,可以用 bitset 进行优化。

但是此题有一个很玄学的地方,就是后边遍历求集合答案的时候要倒着遍历,因为从顶点开始算。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <bitset>

#define rint register int
#define endl '\n'

const int N = 6e4 + 5;

int h[N], ne[N], e[N], idx, in[N], n, m;
std::queue<int> q;
int ans[N];
std::bitset<N> f[N];
int top;

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void toposort()
{
    for (rint i = 1; i <= n; i++)
    {
        if (!in[i])
        {
            q.push(i);
        }
    }
    int cnt = 0;
    while (!q.empty())
    {
        int x = q.front();
        q.pop();
        ans[++cnt] = x;
        for (int i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            in[y]--;
            if (in[y] == 0)
            {
                q.push(y);
            }
        }
    }
    top = cnt;
}

void calc()
{
    for (rint i = top; i != 0; i--)
    {
        int x = ans[i];
        f[x][x] = 1;
        for (rint j = h[x]; j; j = ne[j])
        {
            int y = e[j];
            f[x] |= f[y];
        }
    }
}

int main()
{
    scanf("%d%d", &n, &m);

    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        in[b]++;
    }

    toposort();
    calc();

    for (rint i = 1; i <= n; i++)
    {
        printf("%d\n", f[i].count());
    }

    return 0;
}

[NOIP2020] 排水系统

傻逼高精,此处用 __int128

仔细读题,发现有这么一句话:“该城市的排水系统设计科学,管道不会形成回路,即不会发生污水形成环流的情况。”

什么意思?这句话告诉我们这一定是一个 DAG 图。

直接拓扑排序就可以了。但是得写一个分数加法乘法,很麻烦啊。

PS:经本人亲测,如果不拓扑直接广搜不写 __int128 也能搞到 80pts。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <vector>

#define rint register int
#define endl '\n'
#define ll __int128

const int N = 1e6 + 5;

int h[N], ne[N], e[N], idx, in[N], n, m;
std::queue<int> q;
std::vector<int> ans;
__int128 d[N];

ll gcd(ll a, ll b) 
{ 
    return b ? gcd(b, a % b) : a; 
}
ll lcm(ll a, ll b) 
{ 
    return a / gcd(a, b) * b; 
}

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

struct node
{
    ll p, q;
    node()
    {
        p = 0, q = 1;
    }
    node operator*(const ll &s) const
    {
        node k;
        k.p = p;
        k.q = q * s;
        ll g = gcd(k.p, k.q);
        k.p /= g;
        k.q /= g;
        return k;
    }
    node operator+(const node &s) const
    {
        node k;
        k.q = lcm(q, s.q);
        k.p += p * (k.q / q);
        k.p += s.p * (k.q / s.q);
        ll g = gcd(k.p, k.q);
        k.p /= g;
        k.q /= g;
        return k;
    }
} val[N];

void toposort()
{
    for (rint i = 1; i <= n; i++)
    {
        if (!in[i])
        {
            q.push(i);
        }
    }
    while (!q.empty())
    {
        int x = q.front();
        q.pop();
        if (d[x])
        {
            val[x] = val[x] * d[x];
        }
        for (int i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            in[y]--;
            val[y] = val[x] + val[y];
            if (in[y] == 0)
            {
                q.push(y);
            }
        }
    }
}

void print(ll n)
{
    if (n > 9)
        print(n / 10);
    putchar(n % 10 + 48);
}

int main()
{
    scanf("%d%d", &n, &m);

    for (rint i = 1; i <= n; i++)
    {
        if (i <= m)
        {
            val[i].p = 1;
        }
        scanf("%d", &d[i]);
        if (!d[i])
        {
            ans.push_back(i);
        }
        for (rint j = 1; j <= d[i]; j++)
        {
            int x;
            scanf("%d", &x);
            add(i, x);
            in[x]++;
        }
    }
    
    toposort();
    
    for (rint i = 0; i < ans.size(); i++)
    {
        print(val[ans[i]].p);
        printf(" ");
        print(val[ans[i]].q);
        puts("");
    }

    return 0;
}

8.欧拉路径和欧拉回路

首先分别以模板题为例子。

【模板】欧拉路径

时空复杂度 \(O(m\)\(log\)\(m)\)

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 2e5 + 5;
 
int n, m;
int idx;
int top, out[N], in[N];
int h[N], e[N], f[N], s[N], ne[N];

struct rec
{
    int u, v;
} a[N];      

bool cmp(rec a, rec b)
{
    if (a.u != b.u)
    {
        return a.u < b.u;		
    }
    return a.v > b.v;
}    

void add(int a, int b) 
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x) 
{
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
	if (!f[i]) 
        {
            f[i] = 1; 
            dfs(y);
        }		
    }
    s[++top] = x; 
}

int main()
{
    int l = 0, r = 0;
    int u = 1;
    // l 记录出度比入度大 1 的点的个数
    // r 记录入度比出度大 1 的点的个数
    // u 表示起点,初始为 1
    
    scanf("%d%d", &n, &m);
    
    for (rint i = 1; i <= m; i++)
    {
        scanf("%d%d", &a[i].u, &a[i].v);
        out[a[i].u]++;
        in[a[i].v]++;
    } 
    
    for (rint i = 1; i <= n; ++i)
    {
        if (out[i] - in[i] > 1 || in[i] - out[i] > 1) return puts("No"), 0;
	//如果入度与出度相差大于1,那么不存在欧拉路径
        if (out[i] - in[i] == 1) l++, u = i;
	//出度比入度大1,作为起点
        else if (in[i] - out[i] == 1) r++;
	//入度比出度大1
    }
    
    if ((l == 0 && r == 0) || (l == 1 && r == 1)) //有解
    {
        std::sort(a + 1, a + m + 1, cmp);
        for (rint i = 1; i <= m; i++)
        {
	    add(a[i].u, a[i].v); 
	}
        dfs(u);                  
        while (top != 0)
        {
            printf("%d ", s[top]);
            top--;
        }
        return 0;
    }
    puts("no");
    return 0;
}

对 dfs 优化后可以优化为 \(O(n + m)\)

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 2e5 + 5;
 
int n, m;
int idx;
int top, out[N], in[N];
int h[N], e[N], s[N], ne[N];

struct rec
{
    int u, v;
} a[N];      

bool cmp(rec a, rec b)
{
    if (a.u != b.u)
    {
        return a.u < b.u;		
    }
    return a.v > b.v;
}    

void add(int a, int b) 
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x) 
{
    for (rint i = h[x]; i; i = h[x])
    {
        int y = e[i];
        h[x] = ne[i];
        dfs(y);
    }
    s[++top] = x; 
}

int main()
{
    int l = 0, r = 0;
    int u = 1;
    
    scanf("%d%d", &n, &m);
    
    for (rint i = 1; i <= m; i++)
    {
        scanf("%d%d", &a[i].u, &a[i].v);
        out[a[i].u]++;
        in[a[i].v]++;
    } 
    
    for (rint i = 1; i <= n; ++i)
    {
        if (out[i] - in[i] > 1 || in[i] - out[i] > 1) return puts("No"), 0;
        if (out[i] - in[i] == 1) l++, u = i;
        else if (in[i] - out[i] == 1) r++;
    }
    
    if ((l == 0 && r == 0) || (l == 1 && r == 1)) 
    {
        std::sort(a + 1, a + m + 1, cmp);
        for (rint i = 1; i <= m; i++)
        {
	    add(a[i].u, a[i].v); 
	}
        dfs(u);                  
        while (top != 0)
        {
            printf("%d ", s[top]);
            top--;
        }
        return 0;
    }
    puts("no");
    return 0;
}

AcWing.1184 欧拉回路

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e6 + 5;
const int M = 4e6 + 5;

int h[N], e[M], ne[M], idx = 1;
int t, n, m;
int cnt;
int in[N], out[N], ans[M];
bool vis[M];

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

bool check()
{
    for (rint i = 1; i <= n; i++)
    {
        if (t == 1 && (in[i] + out[i]) & 1)
            return 0;
        if (t == 2 && in[i] != out[i])
            return 0;
    }
    return 1;
}

void dfs(int x)
{
    for (rint i = h[x]; i; i = ne[i])
    {
        if (vis[i])
        {
            continue;
        }
        vis[i] = 1;
        if (t == 1)
        {
            vis[i ^ 1] = 1;
        }
        int val = t == 1 ? i / 2 : i - 1;
        if (t == 1 && (i & 1))
        {
            val = -val;
        }
        int y = e[i];
        dfs(y);
        ans[++cnt] = val;
    }
}

int main()
{
    scanf("%d%d%d", &t, &n, &m);

    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        if (t == 1)
        {
            add(b, a);
        }
        in[b]++;
        out[a]++;
    }

    if (check() == false)
    {
        return puts("NO"), 0;
    }

    for (rint i = 1; i <= n; i++)
    {
        if (h[i])
        {
            dfs(i);
            break;
        }
    }

    if (cnt != m)
    {
        return puts("NO"), 0;
    }

    puts("YES");
    for (rint i = cnt; i != 0; i--)
    {
        printf("%d ", ans[i]);
    }

    return 0;
}

[USACO] 骑马修栅栏

欧拉回路计数问题,但是开临界矩阵更方便一些。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e3 + 5;
const int M = 1e5 + 5;

int n;
int g[N][N];
int d[M];
int ans[M];
int cnt;

void dfs(int x)
{
    for (rint i = 1; i <= 500; i++)
    {
        if (g[x][i])
        {
            g[x][i]--;
            g[i][x]--;
            dfs(i);
        }
    }
    ans[++cnt] = x;
}

int main()
{
    scanf("%d", &n);
    int s = 1;

    for (rint i = 1; i <= n; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        g[a][b]++;
        g[b][a]++;
        d[a]++;
        d[b]++;
    }

    for (rint i = 1; i <= 500; i++)
    {
        if (d[i] % 2 == 1)
        {
            s = i;
            break;
        }
    }

    dfs(s);

    for (rint i = cnt; i >= 1; i--)
    {
        printf("%d\n", ans[i]);
    }

    return 0;
}

P1341 无序字母对

非常板子,和上面那道题基本一样,但是有个很致命的问题,就是数据水了,正常来说要加一个并查集,但是我懒得写了,就在最后加了个特判。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 3e2 + 5;
const int M = 1e5 + 5;

int n;
int cnt;
int g[N][N], d[N];
int s;
char res[M];

void dfs(int x)
{
    for (rint i = 1; i <= 300; i++)
    {
        if (g[x][i])
        {
            g[x][i]--;
            g[i][x]--;
            dfs(i);
        }
    }
    res[n--] = x;
}

int main()
{
    scanf("%d", &n);

    for (rint i = 1; i <= n; i++)
    {
        char a, b;
        std::cin >> a >> b;
        g[a][b]++;
        g[b][a]++;
        d[a]++;
        d[b]++;
    }

    for (rint i = 1; i <= 300; i++)
    {
        if (d[i] % 2 == 1)
        {
            cnt++;
        }
    }

    //找度数为奇数的点
    for (rint i = 1; i <= 300; i++)
    {
        if (d[i] % 2 == 1)
        {
            s = i;
            break;
        }
    }

    //找不到奇点,就是另外找点
    if (s == 0)
    {
        for (rint i = 1; i <= 300; i++)
        {
            if (d[i])
            {
                s = i;
                break;
            }
        }
    }

    if (cnt && cnt != 2)
    {
        puts("No Solution");
        return 0;
    }

    dfs(s);

    if (n >= 0)
    {
        puts("No Solution");
        return 0;
    }

    std::cout << res;

    return 0;
}

P1127 词链

也是一个偏板子的一道题目。

重点是如何将其转化为欧拉回路,同时此题是带边权的欧拉回路。

对每一个单词连一条从首字母指向尾字母的有向边,所建成的图中一定存在欧拉通路或者欧拉回路。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e3 + 5;

std::string str[N], ans[N];
int n;
int idx, h[30], ne[N], e[N], w[N], d[N];
int cnt;
bool v[N];

void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x)
{
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        int z = w[i];
        if (!v[z])
        {
            v[z] = 1;
            dfs(y);
            ans[++cnt] = str[z];
        }
    }
}

int main()
{
    int s;
    scanf("%d", &n);

    for (rint i = 1; i <= n; i++)
    {
        std::cin >> str[i];
    }

    std::sort(str + 1, str + 1 + n);

    s = str[1][0] - 'a' + 1;

    for (rint i = n; i >= 1; i--)
    {
        int a, b;
        a = str[i][0] - 'a' + 1;
        b = str[i][str[i].size() - 1] - 'a' + 1;
        add(a, b, i);
        d[a]++;
        d[b]++;
    }

    for (rint i = 1; i <= 26; i++)
    {
        if (d[i] % 2)
        {
            s = i;
            break;
        }
    }

    dfs(s);

    for (rint i = 1; i <= n; i++)
    {
        if (!v[i])
        {
            return puts("***"), 0;
        }
    }

    for (rint i = cnt; i > 1; i--)
    {
        std::cout << ans[i] << ".";
    }

    std::cout << ans[1];

    return 0;
}

9.二分图的判定

AcWing.860 染色法判二分图

板子题,染色法即可。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;
const int M = 4e5 + 5;

int n, m;
int idx, h[N], e[M], ne[M];
int color[N];
int ans;

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

bool dfs(int x, int c)
{
    color[x] = c;
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (!color[y] && !dfs(y, 3 - c))
            return 0;
        else if (color[y] == c)
            return 0;
    }
    return 1;
}

bool check()
{
    for (rint i = 1; i <= n; i++)
    {
        if (!color[i] && !dfs(i, 1))
        {
            return 0;
        }
    }
    return 1;
}

int main()
{
    scanf("%d%d", &n, &m);
    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }

    if (!check())
    {
        return puts("No"), 0;
    }

    puts("Yes");

    return 0;
}

P1330 封锁阳光大学

加个 a[] 数组记录答案即可。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;
const int M = 4e5 + 5;

int n, m;
int idx, h[N], e[M], ne[M];
int color[N];
int ans;
int a[3];

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

bool dfs(int x, int c)
{
    color[x] = c;
    a[c]++;
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (!color[y] && !dfs(y, 3 - c))
            return 0;
        else if (color[y] == c)
            return 0;
    }
    return 1;
}

int main()
{
    scanf("%d%d", &n, &m);
    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }

    for (rint i = 1; i <= n; i++)
    {
        if (!color[i])
        {
            if (!dfs(i, 1))
            {
                return puts("Impossible"), 0;
            }
            ans += std::min(a[1], a[2]);
            a[1] = a[2] = 0;
        }
    }

    printf("%d", ans);

    return 0;
}

AcWing.695 劣马

由于名字是字符串,用 map 记录答案即可。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <map>
#include <string>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 2e2 + 5;
const int M = 4e2 + 5;

int n;
int idx, h[N], e[M], ne[M];
int color[N];
int ans;
int a[2];
std::map<std::string, int> d;

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

bool dfs(int x, int c)
{
    color[x] = c;
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (!color[y] && !dfs(y, 3 - c))
            return 0;
        else if (color[y] == c)
            return 0;
    }
    return 1;
}

bool check()
{
    for (rint i = 1; i <= n; i++)
    {
        if (!color[i] && !dfs(i, 1))
        {
            return 0;
        }
    }
    return 1;
}

void clear()
{
    memset(color, 0, sizeof color);
    memset(h, 0, sizeof h);
    idx = 0;
    d.clear();
}

int main()
{
    int T;
    scanf("%d", &T);
    int rk = 0;
    while (T--)
    {
        clear();
        scanf("%d", &n);
        int cnt = 0;
        for (rint i = 1; i <= n; i++)
        {
            std::string a, b;
            std::cin >> a >> b;
            if (!d.count(a)) d[a] = ++cnt;
            if (!d.count(b)) d[b] = ++cnt;
            add(d[a], d[b]);
        }
        printf("Case #%d: ", ++rk);

        if (!check())
            puts("No");
        else
            puts("Yes");
    }

    return 0;
}

10.最近公共祖先

【模板】最近公共祖先

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 5e5 + 5;
const int M = 1e6 + 5;

int idx, ne[M], h[N], e[M];
int n, m, root;
int fa[N][21], d[N];

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x)
{
    d[x] = d[fa[x][0]] + 1;
    for (rint i = 0; fa[x][i]; i++)
    {
        fa[x][i + 1] = fa[fa[x][i]][i];
    }
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (d[y])
        {
            continue;
        }
        fa[y][0] = x;
        dfs(y);
    }
}

int lca(int x, int y)
{
    if (d[x] < d[y])
    {
        std::swap(x, y);
    }
    for (rint i = 20; i >= 0; i--)
    {
        if (d[fa[x][i]] >= d[y])
        {
            x = fa[x][i];
        }
    }
    for (rint i = 20; i >= 0; i--)
    {
        if (fa[x][i] != fa[y][i])
        {
            x = fa[x][i];
            y = fa[y][i];
        }
    }
    if (x == y)
    {
        return x;
    }
    return fa[x][0];
}

void init()
{
    dfs(root);
    for (rint j = 1; j <= 20; j++)
    {
        for (rint i = 1; i <= n; i++)
        {
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
        }
    }
}

int main()
{
    scanf("%d%d%d", &n, &m, &root);

    for (rint i = 1; i < n; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }

    init();

    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        printf("%d\n", lca(a, b));
    }

    return 0;
}

AcWing.1171 距离

带边权 LCA 板子题

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 2e4 + 5;
const int M = 4e4 + 1;

int n, m, root, cnt;
int d[N], fa[N][21];
int h[N], e[M], ne[M], idx, w[M];
int dist[N];

void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x)
{
    d[x] = d[fa[x][0]] + 1;
    for (rint i = 0; fa[x][i]; i++)
    {
        fa[x][i + 1] = fa[fa[x][i]][i];
    }
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        int z = w[i];
        if (d[y])
        {
            continue;
        }
        dist[y] = dist[x] + z;
        fa[y][0] = x;
        dfs(y);
    }
}

int lca(int x, int y)
{
    if (d[x] < d[y])
    {
        std::swap(x, y);
    }
    for (rint i = 20; i >= 0; i--)
    {
        if (d[fa[x][i]] >= d[y])
        {
            x = fa[x][i];
        }
    }
    for (rint i = 20; i >= 0; i--)
    {
        if (fa[x][i] != fa[y][i])
        {
            x = fa[x][i];
            y = fa[y][i];
        }
    }
    if (x == y)
    {
        return x;
    }
    return fa[x][0];
}

void init()
{
    root = 1;
    dfs(root);
    for (rint j = 1; j <= 20; j++)
    {
        for (rint i = 1; i <= n; i++)
        {
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
        }
    }
}

int main()
{
    scanf("%d%d", &n, &m);

    for (rint i = 1; i < n; i++)
    {
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        add(a, b, c);
        add(b, a, c);
    }

    init();

    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        printf("%d\n", dist[a] + dist[b] - 2 * dist[lca(a, b)]);
    }

    return 0;
}

P3398 仓鼠找 sugar

如果两条路径相交,那么一定有一条路径的 LCA 在另一条路径上

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 5e5 + 5;
const int M = 1e6 + 5;

int idx, ne[M], h[N], e[M];
int n, m, root;
int fa[N][21], d[N];

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void dfs(int x)
{
    d[x] = d[fa[x][0]] + 1;
    for (rint i = 0; fa[x][i]; i++)
    {
        fa[x][i + 1] = fa[fa[x][i]][i];
    }
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (d[y])
        {
            continue;
        }
        fa[y][0] = x;
        dfs(y);
    }
}

int lca(int x, int y)
{
    if (d[x] < d[y])
    {
        std::swap(x, y);
    }
    for (rint i = 20; i >= 0; i--)
    {
        if (d[fa[x][i]] >= d[y])
        {
            x = fa[x][i];
        }
    }
    for (rint i = 20; i >= 0; i--)
    {
        if (fa[x][i] != fa[y][i])
        {
            x = fa[x][i];
            y = fa[y][i];
        }
    }
    if (x == y)
    {
        return x;
    }
    return fa[x][0];
}

void init()
{
    root = 1;
	dfs(root);
    for (rint j = 1; j <= 20; j++)
    {
        for (rint i = 1; i <= n; i++)
        {
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
        }
    }
}

int main()
{
    scanf("%d%d", &n, &m);

    for (rint i = 1; i < n; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }

    init();

    for (rint i = 1; i <= m; i++)
    {
        int x1, y1, x2, y2;
        scanf("%d%d%d%d", &x1, &y1, &x2, &y2);
        int ans;

        ans = std::max(d[lca(x1, x2)], d[lca(x1, y2)]);
        ans = std::max(ans, d[lca(y1, x2)]);
        ans = std::max(ans, d[lca(y1, y2)]);

        if (ans >= d[lca(x1, y1)] && ans >= d[lca(x2, y2)])
        {
            puts("Y");
        }
        else
        {
            puts("N");
        }
    }

    return 0;
}

11.Tarjan

【模板】割点

在无向连通图中,如果将其中一个点以及所有连接该点的边去掉,图就不再连通,那么这个点就叫做割点。

这个就是一个非常基础的模板题。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;
const int M = 2e5 + 5;

int h[N], e[M], ne[M], idx;
int dfn[N], low[N];
int n, m, times, root;
bool cut[N];

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void tarjan(int x)
{
    dfn[x] = low[x] = ++times;
    int flag = 0;
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (!dfn[y])
        {
            tarjan(y);
            low[x] = std::min(low[x], low[y]);
            if (low[y] >= dfn[x])
            {
                flag++;
                if (x != root || flag > 1)
                {
                    cut[x] = 1;
                }
            }
        }
        else
        {
            low[x] = std::min(low[x], dfn[y]);
        }
    }
}

int main()
{
    scanf("%d%d", &n, &m);
    
    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        if (a == b)
        {
            continue;
        }
        add(a, b);
        add(b, a);
    }
    for (rint i = 1; i <= n; i++)
    {
        if (!dfn[i])
        {
            root = i;
            tarjan(i);
        }
    }

    int cnt = 0;
    for (rint i = 1; i <= n; i++)
    {
        if (cut[i])
        {
            cnt++;
        }
    }

    printf("%d\n", cnt);

    for (rint i = 1; i <= n; i++)
    {
        if (cut[i])
        {
            printf("%d ", i);
        }
    }

    return 0;
}

【模板】缩点

此题在缩点后还得求路径,tarjan + toposort 即可

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>

#define rint register int
#define endl '\n'

using namespace std;

const int N = 1e6 + 5;
const int M = 2e6 + 5;

int n, m, times, top, val[N];
int num[N], dfn[N], low[N], stk[N];
bool vis[N];
int in[N], dist[N];
int idx, h[N], e[N], ne[M];
int from[M], h1[N];
// num 表示所属连通块的编号
std::queue<int> q;
int ans;

void add(int a, int b)
{
    e[++idx] = b, from[idx] = a, ne[idx] = h[a], h[a] = idx;
}

void tarjan(int x)
{
    dfn[x] = low[x] = ++times;
    stk[++top] = x;
    vis[x] = 1;
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (!dfn[y])
        {
            tarjan(y);
            low[x] = min(low[x], low[y]);
        }
        else if (vis[y])
        {
            low[x] = min(low[x], low[y]);
        }
    }
    if (dfn[x] == low[x])
    {
        int y;
        while (y = stk[top--])
        {
            num[y] = x;
            vis[y] = 0;
            if (x == y)
            {
                break;
            }
            val[x] += val[y];
        }
    }
}

void toposort()
{
    for (rint i = 1; i <= n; i++)
    {
        if (!in[i] && num[i] == i)
        {
            q.push(i);
            dist[i] = val[i];
        }
    }
    while (!q.empty())
    {
        int x = q.front();
        q.pop();
        for (rint i = h1[x]; i; i = ne[i])
        {
            int y = e[i];
            in[y]--;
            dist[y] = max(dist[y], dist[x] + val[y]);
            if (in[y] == 0)
            {
                q.push(y);
            }
        }
    }

    for (rint i = 1; i <= n; i++)
    {
        ans = max(ans, dist[i]);
    }
}

int main()
{
    cin >> n >> m;
    for (rint i = 1; i <= n; i++)
    {
        cin >> val[i];
    }
    for (rint i = 1; i <= m; i++)
    {
        int a, b;
        cin >> a >> b;
        add(a, b);
    }

    for (rint i = 1; i <= n; i++)
    {
        if (!dfn[i])
        {
            tarjan(i);
        }
    }
    int cnt = 0;
    for (rint i = 1; i <= m; i++)
    {
        int y = e[i];
        int z = from[i];
        int u = num[z];
        int v = num[y];
        if (u != v)
        {
            ne[++cnt] = h1[u], e[cnt] = v, from[cnt] = u, h1[u] = cnt;
            in[v]++;
        }
    }

    toposort();
    printf("%d\n", ans);

    return 0;
}

[国家集训队] 稳定婚姻

开个 map 映射存储答案即可。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <map>
#include <string>

#define rint register int
#define endl '\n'

using namespace std;

const int N = 1e6 + 5;
const int M = 2e6 + 5;

int n, m, times, top;
int num[N], dfn[N], low[N], stk[N];
bool vis[N];
int idx, h[N], e[N], ne[M];
std::map<std::string, int> s;

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void tarjan(int x)
{
    dfn[x] = low[x] = ++times;
    stk[++top] = x;
    vis[x] = 1;
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (!dfn[y])
        {
            tarjan(y);
            low[x] = min(low[x], low[y]);
        }
        else if (vis[y])
        {
            low[x] = min(low[x], low[y]);
        }
    }
    if (dfn[x] == low[x])
    {
        int y;
        while (y = stk[top--])
        {
            num[y] = x;
            vis[y] = 0;
            if (x == y)
            {
                break;
            }
        }
    }
}

int main()
{
    scanf("%d", &n);

    for (rint i = 1; i <= n; i++)
    {
        std::string a, b; // a is girl, b is boy
        std::cin >> a >> b;
        s[a] = i;
        s[b] = i + n;
        add(i, i + n);
    }

    scanf("%d", &m);

    for (rint i = 1; i <= m; i++)
    {
        std::string a, b;
        std::cin >> a >> b;
        add(s[b], s[a]);
    }

    for (rint i = 1; i <= n * 2; i++)
    {
        if (!dfn[i])
        {
            tarjan(i);
        }
    }

    for (rint i = 1; i <= n; i++)
    {
        if (num[i] == num[i + n])
            puts("Unsafe");
        else
            puts("Safe");
    }

    return 0;
}

[GZOI2017] 小z玩游戏

看的第一篇题解。

若当前兴奋值为 \(a\),某个游戏有趣程度为 \(w\),兴奋程度为 \(k\)\(w\)\(a\) 的倍数。

可以视为兴奋值先从 \(a\) 变成了它的倍数 \(w\),再变成 \(k\)

那我们就可以对每个数向它的倍数建边,再对于每个游戏从 \(w\)\(k\) 建边。

跑一遍 Tarjan,对于每个游戏判断 \(w\)\(k\) 是否在同一强连通分量。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

using namespace std;

const int N = 3e6 + 5;
const int M = 6e6 + 5;

int n, m, times, top;
int num[N], dfn[N], low[N], stk[N];
bool vis[N];
int idx, h[N], e[N], ne[M];
int a1[N], a2[N];

void add(int a, int b)
{
    e[++idx] = b, ne[idx] = h[a], h[a] = idx;
}

void tarjan(int x)
{
    dfn[x] = low[x] = ++times;
    stk[++top] = x;
    vis[x] = 1;
    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (!dfn[y])
        {
            tarjan(y);
            low[x] = min(low[x], low[y]);
        }
        else if (vis[y])
        {
            low[x] = min(low[x], low[y]);
        }
    }
    if (dfn[x] == low[x])
    {
        int y;
        while (y = stk[top--])
        {
            num[y] = x;
            vis[y] = 0;
            if (x == y)
            {
                break;
            }
        }
    }
}

void clear()
{
    memset(h, 0, sizeof h);
    memset(low, 0, sizeof low);
    memset(dfn, 0, sizeof dfn);
    idx = 0;
}

int main()
{
    int T;
    scanf("%d", &T);
    while (T--)
    {
        clear();
        int n;
        scanf("%d", &n);
        int m = 0;
        for (rint i = 1; i <= n; i++)
        {
            scanf("%d", &a1[i]);
            m = std::max(m, a1[i]);
        }
        for (rint i = 1; i <= n; i++)
        {
            scanf("%d", &a2[i]);
            add(a1[i], a2[i]);
        }
        for (rint i = 1; i + i <= m; i++)
        {
            for (rint j = 2; j * i <= m; j++)
            {
                add(i, i * j);
            }
        }
        for (rint i = 1; i <= n; i++)
        {
            if (!dfn[i])
            {
                tarjan(i);
            }
        }
        int ans = 0;
        for (rint i = 1; i <= n; i++)
        {
            if (num[a1[i]] == num[a2[i]])
            {
                ans++;
            }
        }
        printf("%d\n", ans);
    }

    return 0;
}

12. 网络流

正文开始之前强调一件事儿,注意:

在多次使用 dinic 的时候一定要注意清空 maxflow

【模板】网络最大流

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cstring>

#define rint register int
#define int long long

const int inf = 1e18;
const int N = 1e4 + 5;
const int M = 2e5 + 5;

int h[N], f[M], e[M], ne[M], d[N], idx = 1;
int n, m, S, T, flow, maxflow;

void add(int a, int b, int c)
{
    e[++idx] = b, f[idx] = c, ne[idx] = h[a], h[a] = idx;
    e[++idx] = a, f[idx] = 0, ne[idx] = h[b], h[b] = idx;
}

bool bfs()
{
    memset(d, 0, sizeof d);
    std::queue<int> q;
    q.push(S);
    d[S] = 1;
    while (q.size())
    {
        int x = q.front();
        q.pop();
        for (rint i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            if (f[i] && !d[y])
            {
                q.push(y);
                d[y] = d[x] + 1;
                if (y == T)
                {
                    return true;
                }
            }
        }
    }
    return false;
}

int find(int x, int flow)
{
    if (x == T)
    {
        return flow;
    }
    int k, rest = flow;
    for (rint i = h[x]; i && rest; i = ne[i])
    {
        int y = e[i];
        if (f[i] && d[y] == d[x] + 1)
        {
            k = find(y, std::min(rest, f[i]));
            if (k == 0)
            {
                d[y] = 0;
            }
            f[i] -= k;
            f[i ^ 1] += k;
            rest -= k;
        }
    }
    return flow - rest;
}

void dinic()
{
    while (bfs())
    {
        while ((flow = find(S, inf)))
        {
            maxflow += flow;
        }
    }
}

signed main()
{
    scanf("%lld%lld%lld%lld", &n, &m, &S, &T);
    for (rint i = 1; i <= m; i++)
    {
        int a, b, c;
        scanf("%lld%lld%lld", &a, &b, &c);
        add(a, b, c);
    }
    dinic();
    printf("%lld", maxflow);
    return 0;
}

【模板】最小费用最大流

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cstring>

#define rint register int
#define endl '\n'

const int inf = 1e9;
const int N = 5e3 + 5;
const int M = 1e5 + 5;

int head[N], ne[M], f[M], e[M], w[M], idx = 1;
int dist[N], h[N], pre[N], n, m, s, t;
bool v[N];
int flow = inf, ans, maxflow;
std::priority_queue<std::pair<int, int>> q;

void add(int a, int b, int c, int d)
{
    e[++idx] = b, f[idx] = c, w[idx] = d, ne[idx] = head[a], head[a] = idx;
    e[++idx] = a, f[idx] = 0, w[idx] = -d, ne[idx] = head[b], head[b] = idx;
}

bool dijkstra()
{
    for (rint i = 1; i <= n; i++)
    {
	dist[i] = inf;
    }
    memset(v, 0, sizeof v);
    q.push(std::make_pair(0, s));
    dist[s] = 0;
    while (q.size())
    {
        int x = q.top().second;
        q.pop();
        if (v[x])
        {
            continue;
        }
        v[x] = 1;
        for (rint i = head[x]; i; i = ne[i])
        {
            if (f[i] <= 0)
            {
                continue;
            }
            int y = e[i];
            int z = w[i];
            if (dist[y] > dist[x] + z + h[x] - h[y])
            {
                dist[y] = dist[x] + z + h[x] - h[y];
                pre[y] = i;
                q.push(std::make_pair(-dist[y], y));
            }
        }
    }
    return dist[t] < inf;
}

void MCMA()
{
    int x = t;
    int k = flow;

    while (x != s)
    {
        int i = pre[x];
        k = std::min(k, f[i]);
        x = e[i ^ 1];
    }

    x = t;

    while (x != s)
    {
        int i = pre[x];
        f[i] -= k;
        f[i ^ 1] += k;
        x = e[i ^ 1];
    }

    for (rint i = 1; i <= n; i++)
    {
        h[i] += dist[i];
    }

    flow -= k;
    maxflow += k;
    ans += k * h[t];
}

signed main()
{
    scanf("%d%d%d%d", &n, &m, &s, &t);

    for (rint i = 1; i <= m; i++)
    {
        int a, b, c, d;
        scanf("%d%d%d%d", &a, &b, &c, &d);
        add(a, b, c, d);
    }

    while (dijkstra())
    {
        if (flow)
            MCMA();
        else
            break;
    }

    printf("%d %d", maxflow, ans);

    return 0;
}

[CTSC1999]家园

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstring>
#include <queue>

#define rint register int
#define endl '\n'

const int N = 1101 * 50 + 5;
const int M = (N + 1100 + 20 * 1101) + 5;
const int inf = 1e8;

int n, m, S, k, T, flow, maxflow;
int h[N], e[M], f[M], ne[M], idx = 1;
int d[N];
int fa[30];

struct Ship
{
    int h; //容量
    int r;
    int id[30];
} ships[30];

int get(int x)
{
    return fa[x] == x ? x : get(fa[x]);
}

//因为包括地球和月球,所以总共n+2个点
int calc(int i, int day)
{
    return day * (n + 2) + i;
}

bool check(int n)
{
    int begin = 0;
    if (get(begin) == get(n + 1))
        return 1;
    return 0;
}

void add(int a, int b, int c)
{
    e[++idx] = b, f[idx] = c, ne[idx] = h[a], h[a] = idx;
    e[++idx] = a, f[idx] = 0, ne[idx] = h[b], h[b] = idx;
}

bool bfs()
{
    memset(d, 0, sizeof d);
    std::queue<int> q;
    q.push(S);
    d[S] = 1;
    while (q.size())
    {
        int x = q.front();
        q.pop();
        for (rint i = h[x]; i; i = ne[i])
        {
            int y = e[i];
            if (f[i] && !d[y])
            {
                q.push(y);
                d[y] = d[x] + 1;
                if (y == T)
                {
                    return true;
                }
            }
        }
    }
    return false;
}

int find(int x, int flow)
{
    if (x == T)
    {
        return flow;
    }
    int k, rest = flow;
    for (rint i = h[x]; i && rest; i = ne[i])
    {
        int y = e[i];
        if (f[i] && d[y] == d[x] + 1)
        {
            k = find(y, std::min(rest, f[i]));
            if (k == 0)
            {
                d[y] = 0;
            }
            f[i] -= k;
            f[i ^ 1] += k;
            rest -= k;
        }
    }
    return flow - rest;
}

int dinic()
{
    maxflow = 0;
    while (bfs())
    {
        while ((flow = find(S, inf)))
        {
            maxflow += flow;
        }
    }
    return maxflow;
}

int main()
{
    scanf("%d%d%d", &n, &m, &k);
    S = N - 2;
    T = N - 1;

    for (rint i = 0; i < 30; i++)
    {
        fa[i] = i;
    }

    for (rint i = 0; i < m; i++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        ships[i].h = a;
        ships[i].r = b;

        for (rint j = 0; j < b; j++)
        {
            int id;
            scanf("%d", &id);
            if (id == -1)
            {
                id = n + 1;
            }
            ships[i].id[j] = id;
            if (j != 0)
            {
                int x = ships[i].id[j - 1];
                fa[get(x)] = get(id);
            }
        }
    }

    if (!check(n))
    {
        return puts("0"), 0;
    }

    add(S, calc(0, 0), k);
    add(calc(n + 1, 0), T, inf);

    int day = 1;
    int res = 0;

    while (1)
    {
        add(calc(n + 1, day), T, inf);
        for (rint i = 0; i <= n + 1; i++)
        {
            add(calc(i, day - 1), calc(i, day), inf);
        }
        for (rint i = 0; i < m; i++)
        {
            int r = ships[i].r;
            int a = ships[i].id[(day - 1) % r];
            int b = ships[i].id[day % r];
            add(calc(a, day - 1), calc(b, day), ships[i].h);
        }
        res += dinic();
        if (res >= k)
        {
            break;
        }
        day++;
    }

    printf("%d\n", day);

    return 0;
}

13.高级数据结构

Part1.平衡树

【模板】普通平衡树

这玩意儿考了也是以模板的形式出现,会个 pb_ds 就行了。

//先放一波平板电视

#include <bits/stdc++.h>
#include <bits/extc++.h>

typedef long long ll;

const int N = 2e5 + 7, M = 1e5 + 7, INF = 0x3f3f3f3f;

using namespace std;
using namespace __gnu_pbds;

tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> tr;
int n, m;
ll k, ans;

int main()
{
    cin >> n;
    for (int i = 1; i <= n; ++i)
    {
        int op;
        cin >> op >> k;

        if (op == 1)
        {
            tr.insert((k << 20) + i);
        }

        if (op == 2)
        {
            tr.erase(tr.lower_bound(k << 20));
        }

        if (op == 3)
        {
            printf("%d\n", tr.order_of_key(k << 20) + 1);
        }

        if (op == 4)
        {
            ans = *tr.find_by_order(k - 1);
            printf("%lld\n", ans >> 20);
        }

        if (op == 5)
        {
            ans = *--tr.lower_bound(k << 20);
            printf("%lld\n", ans >> 20);
        }

        if (op == 6)
        {
            ans = *tr.upper_bound((k << 20) + n);
            printf("%lld\n", ans >> 20);
        }
    }
    return 0;
}

Part2-1.树状数组模板

【模板】树状数组 1

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 2e6 + 5;

int n, m;
int c[N];

int lowbit(int x)
{
    return x & -x;
}

int ask(int x)
{
    int ans = 0;
    for (; x; x -= lowbit(x))
    {
        ans += c[x];
    }
    return ans;
}

void add(int x, int y)
{
    for (; x <= n; x += lowbit(x))
    {
        c[x] += y;
    }
}

signed main()
{
    scanf("%d%d", &n, &m);
    for (rint i = 1; i <= n; i++)
    {
        int a;
        scanf("%d", &a);
        add(i, a);
    }
    for (rint i = 1; i <= m; i++)
    {
        int op, l, r;
        scanf("%d%d%d", &op, &l, &r);
        if (op == 1)
        {
            add(l, r);
        }
        if (op == 2)
        {
            printf("%d\n", ask(r) - ask(l - 1));
        }
    }
    return 0;
}

【模板】树状数组 2

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 2e6 + 5;

int n, m;
int c[N];
int w[N];

int lowbit(int x)
{
    return x & -x;
}

void add(int x, int y)
{
    for (; x <= n; x += lowbit(x))
    {
        c[x] += y;
    }
}

int search(int x)
{
    int ans = 0;
    for (; x; x -= lowbit(x))
    {
        ans += c[x];
    }
    return ans;
}

signed main()
{
    scanf("%d%d", &n, &m);
    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &w[i]);
    }
    for (rint i = 1; i <= m; i++)
    {
        int op;
        scanf("%d", &op);
        if (op == 1)
        {
            int a, b, c;
            scanf("%d%d%d", &a, &b, &c);
            add(a, c);
            add(b + 1, -c);
        }
        if (op == 2)
        {
            int a;
            scanf("%d", &a);
            printf("%d\n", w[a] + search(a));
        }
    }

    return 0;
}

S2OJ.#307 二维树状数组 1

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

typedef long long ll;

const int N = 5e3 + 5;

int n, m;
ll c[N][N];

int lowbit(int x)
{
    return x & -x;
}

void add(int x, int y, int v)
{
    for (; x <= n; x += lowbit(x))
    {
        for (rint k = y; k <= m; k += lowbit(k))
        {
            c[x][k] += v;
        }
    }
}

ll ask(int x, int y)
{
    ll ans = 0;
    for (; x; x -= lowbit(x))
    {
        for (rint k = y; k; k -= lowbit(k))
        {
            ans += c[x][k];
        }
    }
    return ans;
}

int main()
{
    scanf("%d%d", &n, &m);
    int op;
    while (~scanf("%d", &op))
    {
        if (op == 1)
        {
            int a, b, c;
            scanf("%d%d%d", &a, &b, &c);
            add(a, b, c);
        }
        if (op == 2)
        {
            int a, b, c, d;
            scanf("%d%d%d%d", &a, &b, &c, &d);
            printf("%lld\n", ask(a - 1, b - 1) - ask(a - 1, d) - ask(c, b - 1) + ask(c, d));
        }
    }
    return 0;
}

S2OJ.#308 二维树状数组 2

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

typedef long long ll;

const int N = 5e3 + 5;

int n, m;
ll c[N][N];

int lowbit(int x)
{
    return x & -x;
}

void add(int x, int y, int v)
{
    for (; x <= n; x += lowbit(x))
    {
        for (rint k = y; k <= m; k += lowbit(k))
        {
            c[x][k] += v;
        }
    }
}

ll ask(int x, int y)
{
    ll ans = 0;
    for (; x; x -= lowbit(x))
    {
        for (rint k = y; k; k -= lowbit(k))
        {
            ans += c[x][k];
        }
    }
    return ans;
}

int main()
{
    scanf("%d%d", &n, &m);
    int op;
    while (~scanf("%d", &op))
    {
        if (op == 1)
        {
            int a, b, c, d, x;
            scanf("%d%d%d%d%d", &a, &b, &c, &d, &x);
            add(a, b, x);
            add(a, d + 1, -x);
            add(c + 1, b, -x);
            add(c + 1, d + 1, x);
        }
        if (op == 2)
        {
            int a, b;
            scanf("%d%d", &a, &b);
            printf("%lld\n", ask(a, b));
        }
    }
    return 0;
}

Part2-2.树状数组实战

首先是逆序对的问题

P1908 逆序对

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'
#define int long long

const int N = 5e5 + 5;

struct node
{
    int val;
    int num;
    bool friend operator<(node a, node b)
    {
        if (a.val == b.val)
            return a.num < b.num;
        return a.val < b.val;
    }
} a[N];

int c[N];
int n;
int b[N];
int cnt;
int rank[N];

int lowbit(int x)
{
    return x & -x;
}

void add(int x, int y)
{
    for (; x <= n; x += lowbit(x))
    {
        c[x] += y;
    }
}

int search(int x)
{
    int ans = 0;
    for (; x; x -= lowbit(x))
    {
        ans += c[x];
    }
    return ans;
}

signed main()
{
    scanf("%lld", &n);
    for (rint i = 1; i <= n; i++)
    {
        scanf("%lld", &a[i].val);
        a[i].num = i;
    }
    std::sort(a + 1, a + n + 1);
    for (rint i = 1; i <= n; i++)
    {
        rank[a[i].num] = i;
    }
    for (rint i = 1; i <= n; i++)
    {
        b[i] = i - 1 - search(rank[i]);
        cnt += b[i];
        add(rank[i], 1);
    }
    printf("%lld", cnt);

    return 0;
}

[USACO] Haircut

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'
#define int long long

const int N = 2e5 + 5;

int c[N];
int n;
int a[N], b[N];
int cnt;
int h[N];

int lowbit(int x)
{
    return x & -x;
}

void add(int x, int y)
{
    for (; x <= n; x += lowbit(x))
    {
        c[x] += y;
    }
}

int search(int x)
{
    int ans = 0;
    for (; x; x -= lowbit(x))
    {
        ans += c[x];
    }
    return ans;
}

signed main()
{
    scanf("%lld", &n);
    for (rint i = 1; i <= n; i++)
    {
        scanf("%lld", &a[i]);
        a[i]++; //输入数据中存在零
        h[a[i]] += i - 1 - search(a[i]);
        //存储大于 a[i] 的数有几个
        add(a[i], 1);
    }
    int ans = 0;
    for (rint i = 0; i <= n - 1; i++)
    {
        ans += h[i];
        printf("%lld\n", ans);
    }

    return 0;
}

[NOIO-S 2020] 冒泡排序

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'
#define int long long

const int N = 2e5 + 5;

int c[N];
int n, m;
int a[N], b[N];
int cnt;
int h[N]; 

int lowbit(int x)
{
    return x & -x;
}

void add(int x, int y)
{
    for (; x <= n; x += lowbit(x))
    {
        c[x] += y;
    }
}

int search(int x)
{
    int ans = 0;
    for (; x; x -= lowbit(x))
    {
        ans += c[x];
    }
    return ans;
}

signed main()
{
    scanf("%lld%lld", &n, &m);
    for (rint i = 1; i <= n; i++)
    {
        scanf("%lld", &a[i]);
        b[i] = i - 1 - search(a[i]);
        cnt += b[i];
        h[b[i]]++;
        add(a[i], 1);
    }
    memset(c, 0, sizeof c);
    add(1, cnt);
    int ret = 0;
    for (rint i = 1; i <= n; i++)
    {
        ret += h[i - 1];
        add(i + 1, -(n - ret));
    }
    for (rint i = 1; i <= m; i++)
    {
        int op, x;
        scanf("%lld%lld", &op, &x);
        x = std::min(x, n - 1);
        if (op == 1)
        {
            if (a[x] < a[x + 1])
            {
                std::swap(a[x], a[x + 1]);
                std::swap(b[x], b[x + 1]);
                add(1, 1);
                add(b[x + 1] + 2, -1);
                b[x + 1]++;
            }
            else
            {
                std::swap(a[x], a[x + 1]);
                std::swap(b[x], b[x + 1]);
                add(1, -1);
                b[x]--;
                add(b[x] + 2, 1);
            }
        }
        else
        {
            printf("%lld\n", search(x + 1));
        }
    }

    return 0;
}

前缀和查询的问题

P3801 红色的幻想乡

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'
#define int long long

const int N = 4e5 + 5;

int n, m, q;
int c1[N], c2[N];
int a[N], b[N];

int lowbit(int x)
{
    return x & -x;
}

void add1(int x, int y)
{
    for (; x <= n; x += lowbit(x))
    {
        c1[x] += y;
    }
}

int search1(int x)
{
    int ans = 0;
    for (; x; x -= lowbit(x))
    {
        ans += c1[x];
    }
    return ans;
}

void add2(int x, int y)
{
    for (; x <= m; x += lowbit(x))
    //这个点得注意一下
    {
        c2[x] += y;
    }
}

int search2(int x)
{
    int ans = 0;
    for (; x; x -= lowbit(x))
    {
        ans += c2[x];
    }
    return ans;
}

signed main()
{
    scanf("%lld%lld%lld", &n, &m, &q);
    while (q--)
    {
        int op, x1, y1, x2, y2;
        scanf("%lld", &op);
        if (op == 1)
        {
            scanf("%lld%lld", &x1, &y1);
            if (a[x1] == 1)
                add1(x1, -1);
            else
                add1(x1, 1);
            a[x1] ^= 1;
            //如果是 1 就变成 0, 如果是 0 就变成 1

            if (b[y1] == 1)
                add2(y1, -1);
            else
                add2(y1, 1);
            b[y1] ^= 1;
        }
        else
        {
            scanf("%lld%lld%lld%lld", &x1, &y1, &x2, &y2);
            //纵坐标之差乘横着的迷雾条数加横坐标之差乘纵着的迷雾条数减横着的迷雾条数乘纵着的迷雾条数乘二
            printf("%lld\n", (x2 - x1 + 1) * (search2(y2) - search2(y1 - 1)) + (y2 - y1 + 1) * (search1(x2) - search1(x1 - 1)) - (search2(y2) - search2(y1 - 1)) * (search1(x2) - search1(x1 - 1)) * 2);
        }
    }
    return 0;
}

[USACO] Generic Cow Protests

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'
#define int long long

const int N = 1e5 + 5;
const int mod = 1e9 + 9;

int n, m, q;
int c[N];
int a[N], b[N];
int p[N];
int s[N];
int ans;

int lowbit(int x)
{
    return x & -x;
}

void add(int x, int y)
{
    for (; x <= n + 1; x += lowbit(x))
    {
        c[x] += y;
        c[x] %= mod;
    }
}

int search(int x)
{
    int ans = 0;
    for (; x; x -= lowbit(x))
    {
        ans += c[x];
        ans %= mod;
    }
    return ans;
}

signed main()
{
    scanf("%lld", &n);
    for (rint i = 1; i <= n; i++)
    {
        scanf("%lld", &a[i]);
        s[i] = s[i - 1] + a[i];
        a[i] = s[i];
    }

    std::sort(a + 1, a + n + 1);
    for (rint i = 0; i <= n; i++)
    {
        s[i] = std::lower_bound(a + 1, a + n + 1, s[i]) - a + 1;
    }
    add(s[0], 1);
    for (rint i = 1; i <= n; i++)
    {
        ans = search(s[i]);
        add(s[i], ans);
    }

    printf("%lld", ans % mod);

    return 0;
}

Part3.线段树

其实线段树这个东西主要还是看理解,所以就只拿一个模板题来练练手就好。

【模板】线段树 2

#include <bits/stdc++.h>

#define int long long
#define rint register int
#define endl '\n'

using namespace std;

const int N = 1e6 + 5;

int n, m, mod, w[N];

struct SegmentTree
{
    int l, r;
    int add, sum, mul;
} t[N << 2];

void push_up(int u)
{
    t[u].sum = t[u << 1].sum + t[u << 1 | 1].sum;
    t[u].sum = t[u].sum % mod;
}

void push_down(int u)
{
    t[u << 1].sum = (t[u << 1].sum * t[u].mul + t[u].add * (t[u << 1].r - t[u << 1].l + 1)) % mod;
    t[u << 1 | 1].sum = (t[u << 1 | 1].sum * t[u].mul + t[u].add * (t[u << 1 | 1].r - t[u << 1 | 1].l + 1)) % mod;
    t[u << 1].add = (t[u << 1].add * t[u].mul + t[u].add) % mod;
    t[u << 1 | 1].add = (t[u << 1 | 1].add * t[u].mul + t[u].add) % mod;
    t[u << 1].mul = t[u << 1].mul * t[u].mul % mod;
    t[u << 1 | 1].mul = t[u << 1 | 1].mul * t[u].mul % mod;
    t[u].add = 0;
    t[u].mul = 1;
}

void build(int p, int l, int r)
{
    if (l == r)
    {
        t[p] = {l, r, 0, w[r], 1};
        return;
    }
    t[p] = {l, r, 0, 0, 1};
    int mid = (l + r) >> 1;
    build(p << 1, l, mid);
    build(p << 1 | 1, mid + 1, r);
    push_up(p);
}

void change(int p, int l, int r, int mul, int add)
{
    if (t[p].l >= l && t[p].r <= r)
    {
        t[p].sum = (t[p].sum * mul + add * (t[p].r - t[p].l + 1)) % mod;
        t[p].add = (t[p].add * mul + add) % mod;
        t[p].mul = t[p].mul * mul % mod;
        return;
    }
    push_down(p);
    int mid = (t[p].l + t[p].r) >> 1;
    if (l <= mid)
    {
        change(p << 1, l, r, mul, add);
    }
    if (r > mid)
    {
        change(p << 1 | 1, l, r, mul, add);
    }
    push_up(p);
}

int query(int p, int l, int r)
{
    if (t[p].l >= l && t[p].r <= r)
    {
        return t[p].sum;
    }
    push_down(p);
    int mid = (t[p].l + t[p].r) >> 1;
    int res = 0;
    if (l <= mid)
    {
        res += query(p << 1, l, r);
        res = res % mod;
    }
    if (r > mid)
    {
        res += query(p << 1 | 1, l, r);
        res = res % mod;
    }
    return res;
}

signed main()
{
    cin >> n >> m >> mod;
    for (rint i = 1; i <= n; i++)
    {
        cin >> w[i];
    }
    build(1, 1, n);
    while (m--)
    {
        int op, l, r, c;
        cin >> op >> l >> r;
        if (op == 1)
        {
            cin >> c;
            change(1, l, r, c, 0); //*c+0
        }
        if (op == 2)
        {
            cin >> c;
            change(1, l, r, 1, c); //*1+c,  *0???
        }
        if (op == 3)
        {
            cout << query(1, l, r) << endl;
        }
    }

    return 0;
}

Part4.并查集

[NOI2015] 程序自动分析

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e6 + 5;
const int M = 3e6 + 5;

int n;
int fa[N];
int b[M];

struct node
{
    int x, y;
    int id;
    bool friend operator<(node a, node b)
    {
        return a.id > b.id;
    }
    /*
    按照约数条件排序 
	把所有 id = 1 的操作放在前面
	再进行 id = 0 的操作
	在进行 id = 1 的操作的时候
	把它约束的两个变量放在同一个集合里面即可
	*/
} a[N];

int get(int x)
{
    return fa[x] == x ? x : fa[x] = get(fa[x]);
}

void clear()
{
    memset(b, 0, sizeof b);
    memset(a, 0, sizeof a);
    memset(fa, 0, sizeof fa);
}

int main()
{
    int T;
    scanf("%d", &T);
    while (T--)
    {
        int cnt = 0;
        clear();
        bool flag = true;
        scanf("%d", &n);
        for (rint i = 1; i <= n; i++)
        {
            scanf("%d%d%d", &a[i].x, &a[i].y, &a[i].id);
            b[++cnt] = a[i].x;
            b[++cnt] = a[i].y;
        }
        
        //以下为离散化 
        std::sort(b + 1, b + cnt + 1);
        int tot = std::unique(b + 1, b + cnt + 1) - (b + 1);
        for (rint i = 1; i <= n; i++)
        {
            a[i].x = std::lower_bound(b + 1, b + tot + 1, a[i].x) - (b + 1);
            a[i].y = std::lower_bound(b + 1, b + tot + 1, a[i].y) - (b + 1);
        }
        for (rint i = 0; i <= tot; i++)
        {
            fa[i] = i;
        }
        std::sort(a + 1, a + n + 1);
        for (rint i = 1; i <= n; i++)
        {
            int fx = get(a[i].x);
            int fy = get(a[i].y);
            if (a[i].id == 1)
            {
                fa[fx] = fy;
            }
            if (a[i].id == 0 && fx == fy)
            //如果本来应该不等但是在同一集合 
            {
                puts("NO");
                flag = false;
                break;
            }
        }
        if (flag)
        {
            puts("YES");
        }
    }
    return 0;
}

[JSOI2008] 星球大战

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 4e5 + 5;

int n, m;
int fa[N];
int ans[N];
int num;
int v[N]; //给点标号

struct node
{
    int x, y;
    int id;
    bool friend operator<(node a, node b)
    {
        return a.id < b.id;
    }
} a[N];

int get(int x)
{
    return fa[x] == x ? x : fa[x] = get(fa[x]);
}

void merge(int x, int y)
{
    int fx = get(x);
    int fy = get(y);
    if (fx != fy)
    {
        fa[fx] = fy;
        num--;//集合个数减一 
    }
}

int main()
{
    scanf("%d%d", &n, &m);
    num = n; //假设每个都是独立的集合
    for (rint i = 0; i <= n; i++)
    {
        fa[i] = i;
    }
    for (rint i = 1; i <= m; i++)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        a[i].x = u;
        a[i].y = v;
        a[i].id = 0;
    }
    int k;
    scanf("%d", &k);
    for (rint i = 1; i <= k; i++)
    {
        int att;
        scanf("%d", &att);
        v[att] = k - i + 1;//倒着存 
    }
    for (rint i = 1; i <= m; i++)
    {
        a[i].id = std::max(v[a[i].x], v[a[i].y]);
    }
    std::sort(a + 1, a + m + 1);

    rint j = 1;
    for (rint i = 0; i <= k; i++)
    {
        while (a[j].id == i)
        {
            merge(a[j].x, a[j].y);
            j++;
        }
        ans[i] = num - (k - i);
    }

    for (rint i = k; i >= 0; i--)
    {
        printf("%d\n", ans[i]);
    }

    return 0;
}

[NOI2001] 食物链

开普通并查集的三倍

$x 元素所在集合中所有 \(∈[1,n]\) 的元素都是 \(x\) 元素的同类
\(x+n\) 元素所在集合中所有 \(∈[1,n]\) 的元素都是 \(x\) 元素的天敌
\(x+2n\) 元素所在集合中所有 \(∈[1,n]\) 的元素都是 \(x\) 元素的猎物
\(x+n\) 元素所在的集合中所有 \(∈[1,n]\) 的元素都是 \(x+2n\) 元素的猎物

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 2e5 + 5;

int n, k;
int x, y, z;
int ans;
int fa[N];

int get(int x)
{
    return fa[x] == x ? x : fa[x] = get(fa[x]);
}

void merge(int x, int y)
{
    int fx = get(x);
    int fy = get(y);
    fa[fx] = fy;
}

int main()
{ 
    scanf("%d%d", &n, &k);
    for (rint i = 0; i <= 3 * n; i++)
    {
        fa[i] = i;
    }
    for (rint i = 1; i <= k; i++)
    {
        int op, x, y;
        scanf("%d%d%d", &op, &x, &y);

        if (x > n || y > n)
        //这个动物不存在,显然是假话
        {
            ans++;
            continue;
        }

        if (op == 1)
        {
            if (get(x + n) == get(y) || get(x + 2 * n) == get(y))
            //如果 1 是 2 的天敌或猎物, 显然为谎言
            {
                ans++;
            }
            else
            {
            	//1 的同类是 2 的同类, 1 的猎物是 2 的猎物, 1 的天敌是 2 的天敌
                merge(x, y);
                merge(x + n, y + n);
                merge(x + 2 * n, y + 2 * n);
            }
        }
        
        if (op == 2)
        {
            if (x == y || get(x) == get(y) || get(x + n) == get(y))
            //如果 1 是 2 的同类或猎物, 显然为谎言
            {
                ans++;
            }
            else
            {
            	//1 的同类是 2 的天敌, 1 的猎物是 2 的同类, 1 的天敌是 2 的猎物
                merge(x + 2 * n, y);
                merge(x + n, y + 2 * n);
                merge(x, y + n);
            }
        }
    }

    printf("%d", ans);

    return 0;
}

[NOI2002] 银河英雄传说

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

int fa[N];
int dist[N]; //示飞船 i 与其所在列队头的距离
int num[N];  //表示第 i 列的飞船数量

int get(int x)
{
    if (fa[x] == x)
    {
	return x;
    }
    int ans = get(fa[x]);
    dist[x] += dist[fa[x]];
    return fa[x] = ans;
}

int main()
{
    int T;
    scanf("%d", &T);
    for (rint i = 1; i <= N; i++)
    {
    	fa[i] = i;
	dist[i] = 0;
	num[i] = 1;
    }
    while (T--)
    {
	char op;
	int x, y;
	std::cin >> op;
	scanf("%d%d", &x, &y);

	int fx = get(x);
	int fy = get(y);

	if (op == 'M')
	{
	    dist[fx] += num[fy];
	    fa[fx] = fy;
	    num[fy] += num[fx];
	    num[fx] = 0;
	}
	if (op == 'C')
	{
	    if (fx != fy)
	    //若祖先不同, 则不在一列
	    {
		puts("-1");
	    }
	    else
	    {
		printf("%d\n", std::abs(dist[x] - dist[y]) - 1);
	    }
	}
    }
    return 0;
}

14.数论

Part1.质数和约数

判断质数

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>

#define rint register int
#define endl '\n'

bool isPrime(int n)
{
    if (n < 2)
    {
        return false;
    }
    for (rint i = 2; i <= sqrt(n); i++)
    {
        if (!(n % i))
        {
            return false;
        }
    }
    return true;
}

int main()
{
    int n;
    scanf("%d", &n);
    if (!isPrime(n))
    {
        puts("No");
    }
    else
    {
        puts("Yes");
    }
    return 0;
}

【模板】筛素数

埃氏筛法

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e8 + 5;

int n, q;
int pri[N];
bool vis[N];

void get_primes()
{
    int cnt = 0;
    memset(vis, 0, sizeof vis);
    for (rint i = 2; i <= n; i++)
    {
        if (vis[i])
        {
            continue;
        }
        pri[++cnt] = i;
        for (rint j = 1; j <= n / i; j++)
        {
            vis[i * j] = 1;
        }
    }
}

int main()
{
    scanf("%d%d", &n, &q);
    get_primes();
    while (q--)
    {
        int a;
        scanf("%d", &a);
        printf("%d\n", pri[a]);
    }
    return 0;
}

线性筛法

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e8 + 5;

int n, q;
int pri[N];
int g[N]; //最小质因子

void get_primes()
{
    int cnt = 0;
    memset(g, 0, sizeof g);
    for (rint i = 2; i <= n; i++)
    {
        if (!g[i])
        {
            g[i] = i;
            pri[++cnt] = i;
        }
        for (rint j = 1; j <= cnt; j++)
        {
            if (pri[j] > g[i] || pri[j] > n / i)
            {
                break;
            }
            g[i * pri[j]] = pri[j];
        }
    }
}

int main()
{
    scanf("%d%d", &n, &q);
    get_primes();
    while (q--)
    {
        int a;
        scanf("%d", &a);
        printf("%d\n", pri[a]);
    }
    return 0;
}

筛最小质因子

线性筛法可以同时求出最小质因子和质数,但是埃氏筛不行,所以这里在写一下埃氏筛最小质因子。

void get_prime()
{
    for (rint i = 1; i <= n; i++)
    {
        g[i] = i;
        if (!(g[i] % 2))
        {
            g[i] = 2;
        }
    }

    for (rint i = 3; i < sqrt(n); i += 2)
        for (rint j = i * 2; j <= n; j += i)
            if (g[j] > i)
                g[j] = i;
}

P1835 素数密度

首先有个很显然的做法,直接拿右区间个数减去左区间个数。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 5e7 + 5;

int n;
int pri[N];
bool vis[N];
bool is[N];
int ans[N];

void get_primes()
{
    int cnt = 0;
    memset(vis, 0, sizeof vis);
    for (rint i = 2; i <= n; i++)
    {
        if (vis[i])
        {
            continue;
        }
        pri[++cnt] = i;
        is[pri[cnt]] = 1;
        for (rint j = 1; j <= n / i; j++)
        {
            vis[i * j] = 1;
        }
    }
}

int main()
{
    int l, r;
    scanf("%d%d", &l, &r);
    n = 1e7;
    get_primes();
    for (rint i = 1; i <= 1e7; i++)
    {
        if (is[i])
        {
            ans[i] = ans[i - 1] + 1;
        }
        else
        {
            ans[i] = ans[i - 1];
        }
    }
    printf("%d\n", ans[r] - ans[l - 1]);
    return 0;
}

这个方法会 T,当然也可以直接筛一个区间的:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define int long long
#define endl '\n'

const int N = 5e7 + 5;

int n;
int pri[N];
bool vis[N];
bool is[N];
int ans[N];
int cnt = 0;

void get_primes()
{
    memset(vis, 0, sizeof vis);
    for (rint i = 2; i <= n; i++)
    {
        if (vis[i])
        {
            continue;
        }
        pri[++cnt] = i;
        is[pri[cnt]] = 1;
        for (rint j = 1; j <= n / i; j++)
        {
            vis[i * j] = 1;
        }
    }
}

signed main()
{
    int l, r;
    scanf("%lld%lld", &l, &r);
    if (l == 1)
    {
	l++;
    }
    n = 5e4;
    get_primes();
    for (rint i = 1; i <= cnt; i++)
    {
        for (rint j = std::max(2ll, (l - 1) / pri[i] + 1) * pri[i]; j <= r; j += pri[i])
        {
            if (j >= l)
            {
                ans[j - l] = 1;
            }
        }
    }
    int res = 0;
    for (rint i = 0; i <= r - l; i++)
    {
        if (!ans[i])
        {
            res++;
        }
    }
    printf("%lld\n", res);
    return 0;
}

分解质因数

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>

#define rint register int
#define endl '\n'

const int N = 5e7 + 5;

int n;
int pri[N], c[N];
int cnt = 0;

void divide()
{
    for (rint i = 2; i <= sqrt(n); i++)
    {
        if (!(n % i))
        {
            pri[++cnt] = i;
            c[cnt] = 0;
            while (!(n % i))
            {
                n /= i;
                c[cnt]++;
            }
        }
    }
    if (n > 1)
    {
        pri[++cnt] = n;
        c[cnt] = 1;
    }
}

signed main()
{
    scanf("%d", &n);
    divide();
    for (rint i = 1; i <= cnt; i++)
    {
        printf("%d %d\n", pri[i], c[i]);
    }
}

AcWing.197 阶乘分解

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e7 + 5;

int n;
int pri[N];
bool vis[N];
int cnt = 0;

void get_primes()
{
    memset(vis, 0, sizeof vis);
    for (rint i = 2; i <= n; i++)
    {
        if (vis[i])
        {
            continue;
        }
        pri[++cnt] = i;
        for (rint j = 1; j <= n / i; j++)
        {
            vis[i * j] = 1;
        }
    }
}

int main()
{
    scanf("%d", &n);
    get_primes();
    for (rint i = 1; i <= cnt; i++)
    {
        int cnt = 0;
        for (rint j = n; j; j /= pri[i])
        {
            cnt += j / pri[i];
        }
        printf("%d %d\n", pri[i], cnt);
    }
    return 0;
}

求约数

试除法求一个数的约数:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

int n;
int factor[N];
int cnt = 0;

int main()
{
    scanf("%d", &n);
    for (rint i = 1; i <= sqrt(n); i++)
    {
        if (!(n % i))
        {
            factor[++cnt] = i;
            if (i != n / i)
            {
                factor[++cnt] = n / i;
            }
        }
    }
    for (rint i = 1; i <= cnt; i++)
    {
        printf("%d ", factor[i]);
    }
    return 0;
}

倍数法求 \(n\) 个数的约数:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <vector>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

int n;
std::vector<int> factor[N];
int cnt = 0;

int main()
{
    scanf("%d", &n);

    for (rint i = 1; i <= n; i++)
    {
        for (rint j = 1; j <= n / i; j++)
        {
            factor[i * j].push_back(i);
        }
    }
    for (rint i = 1; i <= n; i++)
    {
        for (rint j = 0; j < factor[i].size(); j++)
        {
            printf("%d ", factor[i][j]);
        }
        puts("");
    }

    return 0;
}

[CQOI2007] 余数求和

这个题其实就是在求 \(n∗k−∑_{i=1}^{n}⌊k/i⌋∗i\)

\(⌊k/i⌋\) 用数论分块就行。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'

typedef long long ll;

ll n, k;
ll ans = 0;

int main()
{
    scanf("%lld%lld", &n, &k);
    ans = n * k;

    int l = 1, r = 0;

    while (l <= n)
    {
        if (k / l)
        {
            r = std::min(k / (k / l), n);
        }
        else
        {
            r = n;
        }
        ans -= (k / l) * (r - l + 1) * (l + r) >> 1;
        l = r + 1;
    }

    printf("%lld\n", ans);

    return 0;
}

Part2.欧几里得算法

\(∀a,b ∈N,b≠0,gcd(a,b) = gcd(b, a \% b)\)

可以借此写出最大公约数函数:

int gcd(int a, int b)
{
    return b ? gcd(b, a % b) : a;
}
int lcm(int a, int b) 
{ 
    return a / gcd(a, b) * b; 
}

Part3.欧拉函数

\(n\) 的欧拉函数。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>

#define rint register int
#define endl '\n'

int n;

int phi(int n)
{
    int ans = n;
    for (rint i = 2; i <= sqrt(n); i++)
    {
        if (!(n % i))
        {
            ans = ans / i * (i - 1);
            while (!(n % i))
            {
                n /= i;
            }
        }
    }
    if (n > 1)
    {
        ans = ans / n * (n - 1);
    }
    return ans;
}

int main()
{
    scanf("%d", &n);
    printf("%d\n", phi(n));
    return 0;
}

\(2 ~ n\) 的欧拉函数。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

int n;
int phi[N];

void euler(int n)
{
    for (rint i = 2; i <= n; i++)
    {
        phi[i] = i;
    }
    for (rint i = 2; i <= n; i++)
    {
        if (phi[i] == i)
        {
            for (rint j = i; j <= n; j += i)
            {
                phi[j] = phi[j] / i * (i - 1);
            }
        }
    }
}

int main()
{
    scanf("%d", &n);
    euler(n);
    printf("%d\n", phi[n]);
    return 0;
}

[SDOI2008] 仪仗队

显然,答案为 \([φ(2) + φ(3) + ... + φ(n - 1)] * 2 + 3\)

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

int n;
int phi[N];

void euler(int n)
{
    for (rint i = 2; i <= n; i++)
    {
        phi[i] = i;
    }
    for (rint i = 2; i <= n; i++)
    {
        if (phi[i] == i)
        {
            for (rint j = i; j <= n; j += i)
            {
                phi[j] = phi[j] / i * (i - 1);
            }
        }
    }
}

int main()
{
    scanf("%d", &n);
    if (n == 1)
    {
        printf("0");
        return 0;
    }
    euler(n);
    long long ans = 0;
    for (rint i = 2; i < n; i++)
    {
        ans += phi[i];
    }
    printf("%lld", (long long)ans * 2 + 3);
    return 0;
}

Part4.同余

费马小定理

\(p\) 为质数,\(a^p ≡ a\) \((mod\) \(p)\)

欧拉定理

若正整数 \(a, n\) 互质,则 $a^{φ(n)} ≡ 1 $ \((mod\) \(n)\)

推论:若正整数 \(a, n\) 互质,则 \(a^b ≡ a ^ {b \mod φ(n)} (\mod n)\)

【模板】扩展欧拉定理

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>

#define rint register int
#define endl '\n'
#define int long long

int a, m;
int b;

int qpow(int a, int b, int p)
{
    a %= p;
    int res = 1;
    for (; b; b >>= 1)
    {
        if (b & 1)
        {
            res = res * a % p;
        }
        a = a * a % p;
    }
    return res;
}

int phi(int n)
{
    int ans = n;
    for (rint i = 2; i <= sqrt(n); i++)
    {
        if (!(n % i))
        {
            ans = ans / i * (i - 1);
            while (!(n % i))
            {
                n /= i;
            }
        }
    }
    if (n > 1)
    {
        ans = ans / n * (n - 1);
    }
    return ans;
}

int read(int mod)
{
    int x = 0;
    bool g = 0;
    char c = getchar();
    while (c < '0' || c > '9')
        c = getchar();
    while (c >= '0' && c <= '9')
    {
        x = (x << 3) + (x << 1) + (c ^ '0');
        if (x >= mod)
        {
            x %= mod;
            g = 1;
        }
        c = getchar();
    }
    if (!g)
        return x;
    return (x + mod);
}

signed main()
{
    scanf("%lld%lld", &a, &m);
    phi(b);
    int times = read(phi(m));
    printf("%lld", qpow(a, times, m));
    return 0;
}

扩展欧几里得算法

对于任意整数 \(a,b\),存在一对整数 \(x,y\) ,满足 \(ax + by = gcd(a, b)\)

int exgcd(int a, int b, int &x, int &y)
{
    if (!b)
    {
        x = 1;
        y = 0;
        return a;
    }
    int d = exgcd(b, a % b, x, y);
    int z = x;
    x = y;
    y = z - y * (a / b);
    return d;
}

[NOIP2012 提高组] 同余方程

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'
#define int long long

int a, b;

int exgcd(int a, int b, int &x, int &y)
{
    if (!b)
    {
        x = 1;
        y = 0;
        return a;
    }
    int d = exgcd(b, a % b, x, y);
    int z = x;
    x = y;
    y = z - y * (a / b);
    return d;
}

signed main()
{
    scanf("%lld%lld", &a, &b);
    int x, y;
    int qwq = exgcd(a, b, x, y);
    printf("%lld", (x % b + b) % b);
    return 0;
}

中国剩余定理

\(m_1,m_2,...,m_n\) 是两两互质的整数,\(m = ∏_{i = 1}^{n}m_i, M_i = m / m_i\)\(t_i\) 是线性同余方程 \(M_it_i ≡ 1(mod\) \(m_i)\) 的一个解。

对于任意的 \(n\) 个整数 \(a_1,a_2,...,a_n\) 的方程组:

\(x ≡ a_1 (mod\) \(m_1)\)
\(x ≡ a_2 (mod\) \(m_2)\)
...
\(x ≡ a_1 (mod\) \(m_i)\)

有整数解,解为 \(x = ∑_{i = 1}^{n}a_iM_it_i\)

【模板】中国剩余定理

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'
#define int long long

const int N = 20;

int n;
int m[N], a[N];
int M[N];

int exgcd(int a, int b, int &x, int &y)
{
    if (!b)
    {
        x = 1;
        y = 0;
        return a;
    }
    int d = exgcd(b, a % b, x, y);
    int z = x;
    x = y;
    y = z - y * (a / b);
    return d;
}

int CRT(int m[], int a[], int k)
{
    int x, y;
    int n = 1;
    int ans = 0;
    for (rint i = 0; i < k; i++)
    {
        n = n * m[i];
    }
    for (rint i = 0; i < k; i++)
    {
        M[i] = n / m[i];
        int qwq = exgcd(m[i], M[i], x, y);
        ans = (ans + y * M[i] * a[i]) % n;
    }
    return ans > 0 ? ans : ans + n;
}

signed main()
{
    scanf("%lld", &n);
    for (rint i = 0; i < n; i++)
    {
        scanf("%lld%lld", &m[i], &a[i]);
    }
    printf("%lld", CRT(m, a, n));
    return 0;
}

Part5.乘法逆元

若整数 \(b,m\) 互质,并且 \(b | a\),则存在一个整数 \(x\),使得 \(a/b ≡ a * x(mod\) \(m)\)。称 \(x\)\(b\) 的模 \(m\) 的乘法逆元,计为 \(b^{-1}(mod\) \(m)\)

【模板】乘法逆元

法一:快速幂求解(适用于单次查询)

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'
#define int long long

int n, p;

int qpow(int a, int b, int p)
{
    a %= p;
    int res = 1;
    for (; b; b >>= 1)
    {
        if (b & 1)
        {
            res = res * a % p;
        }
        a = a * a % p;
    }
    return res;
}

signed main()
{
    scanf("%lld%lld", &n, &p);
    for (rint i = 1; i <= n; i++)
    {
        int x = qpow(i, p - 2, p);
        printf("%lld\n", x);
    }
    return 0;
}

法二:递推(适用于多次查询)

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'
#define int long long

const int N = 3e6 + 5;

int n, p;
int inv[N];

signed main()
{
    scanf("%lld%lld", &n, &p);
    
    inv[0] = 0;
    inv[1] = 1;
    printf("1\n");
	
    for (rint i = 2; i <= n; i++)
    {
	inv[i] = p - (p / i) * inv[p % i] % p;
	printf("%lld\n", inv[i]);
    }
        
    return 0;
}

Part6.组合计数

逆元求组合数:

void init()
{
    fac[0] = inv[0] = 1;
    for (rint i = 1; i <= 200000; i++)
    {
        fac[i] = fac[i - 1] * i % mod;
        inv[i] = qpow(fac[i], mod - 2);
    }
}
int C(int n, int m)
{
    return (inv[m] * inv[n - m] % mod) * fac[n] % mod;
}

Part7.高斯消元

消元法是将方程组中的一方程的未知数用含有另一未知数的代数式表示,并将其代入到另一方程中,这就消去了一未知数,得到一解;或将方程组中的一方程倍乘某个常数加到另外一方程中去,也可达到消去一未知数的目的。消元法主要用于二元一次方程组的求解。

核心:
1.两方程互换,解不变;
2.一方程乘以非零数k,解不变;
3.一方程乘以数k加上另一方程,解不变.

【模板】高斯消元法

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>

#define rint register int
#define endl '\n'

const int N = 1e2 + 5;
const double eps = 1e-6;

int n;
double a[N][N];
bool flag;

void gauss()
{
    for (rint i = 1; i <= n; i++)
    {
        int r = i;
        for (rint j = i + 1; j <= n; j++)
        {
            if (fabs(a[r][i]) < fabs(a[j][i]))
            {
                r = j;
            }
        }
        if (r != i)
        {
            for (rint j = 1; j <= n + 1; j++)
            {
                std::swap(a[i][j], a[r][j]);
            }
        }
        if (fabs(a[i][i]) < eps)
        {
            flag = 1;
            return;
        }
        for (rint j = 1; j <= n; j++)
        {
            if (j != i)
            {
                double tmp = a[j][i] / a[i][i];
                for (rint k = i + 1; k <= n + 1; k++)
                {
                    a[j][k] -= a[i][k] * tmp;
                }
            }
        }
    }
    for (rint i = 1; i <= n; i++)
    {
        a[i][n + 1] /= a[i][i];
    }
}

int main()
{
    scanf("%d", &n);
    for (rint i = 1; i <= n; i++)
    {
        for (rint j = 1; j <= n + 1; j++)
        {
            scanf("%lf", &a[i][j]);
        }
    }
    gauss();
    if (flag)
    {
        return puts("No Solution"), 0;
    }
    for (rint i = 1; i <= n; i++)
    {
        printf("%.2lf\n", a[i][n + 1]);
    }
    return 0;
}

Part8.Lucas 定理

\(C^m _n \mod p = C^{m/p} _{n/p}*C^{m \mod p} _{n \mod p} \mod p\)

【模板】卢卡斯定理

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'
#define int long long

const int N = 1e5 + 5;

int n, m, p;
int fac[N], inv[N];

int qpow(int a, int b)
{
    int res = 1;
    for (; b; b >>= 1)
    {
	if (b & 1)
	{
	    res = res * a % p;
	}
    	a = a * a % p;
    }
    return res;
}

void init()
{
    fac[0] = inv[0] = 1;
    for (rint i = 1; i <= 100000; i++)
    {
	fac[i] = fac[i - 1] * i % p;
	inv[i] = qpow(fac[i], p - 2);
    }
}

int C(int n, int m)
{
    if (m > n) return 0;
    return (inv[m] * inv[n - m] % p) * fac[n] % p;
}

int lucas(int n, int m, int p)
{
    if (!m)
    {
	return 1;
    }
    return C(n % p, m % p) * lucas(n / p, m / p, p) % p;
}

signed main()
{
    int T;
    scanf("%lld", &T);
    while (T--)
    {
	scanf("%lld%lld%lld", &n, &m, &p);
	init();
	n += m;
	printf("%lld\n", (lucas(n, m, p) + p) % p);
    }

    return 0;
}

Part9.题目训练

见博客文章CSP-S 考前数学练习

15.优雅的 dfs

Part1.剪枝

AcWing.165 小猫爬山

有一个类似于记忆化的小剪枝,如果当前答案已经不如之前的最优解,那就没有继续搜下去的必要了,终止搜索。

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'
#define int long long

const int N = 20;
const int inf = 0x3f3f3f3f;

int n, m;
int a[N], p[N];
int res = inf;

bool cmp(int a, int b)
{
    return a > b;
}

void dfs(int x, int cnt)
{
    if (cnt >= res)
    {
        return;
    }
    if (x == n + 1)
    {
        res = std::min(res, cnt);
        return;
    }
    for (rint i = 1; i <= cnt; i++)
    {
        if (a[x] <= m - p[i])
        {
            p[i] += a[x];
            dfs(x + 1, cnt);
            p[i] -= a[x];
        }
    }
    p[++cnt] = a[x];
    dfs(x + 1, cnt);
    p[cnt] = 0;
}

signed main()
{
    scanf("%lld%lld", &n, &m);
    for (rint i = 1; i <= n; i++)
    {
        scanf("%lld", &a[i]);
    }
    std::sort(a + 1, a + n + 1, cmp);
    dfs(1, 0);
    printf("%lld", res);
    return 0;
}

P1120 小木棍

注释在代码里。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e2 + 5;

int n;
int a[N];
bool v[N];
int len, cnt;
int sum;
int maxx;

bool cmp(int a, int b)
{
    return a > b;
}

//last 表示上一个木棍编号
//x 动态的 cnt (计数)
//now 当前长度 
bool dfs(int x, int now, int last)
{
    if (x > cnt)
    {
        return 1;
    }
    if (now == len)
    {
        return dfs(x + 1, 0, 1);
    }

    int fail = 0; 
    //剪枝1, fail != a[i] 可以省去无效搜索 
    for (rint i = last; i <= n; i++)
    //剪枝2, 使加入小木棍的长度递减 
    {
        if (!v[i] && now + a[i] <= len && fail != a[i])
        {
            v[i] = 1;//剪枝3, 打上标记 
            if (dfs(x, now + a[i], i + 1)) 
            {
                return 1;
            }
            fail = a[i];
            v[i] = 0;
            if (now == 0 || now + a[i] == len)
            {
                return 0;
            }
        }
    }

    return 0;
}

int main()
{
    scanf("%d", &n);
    for (rint i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);
        if (a[i] > 50)
        {
            n--;
            i--;
            continue;
        }
        sum += a[i];
        maxx = std::max(maxx, a[i]);
    }

    std::sort(a + 1, a + n + 1, cmp);
    //剪枝4, 优化搜索顺序 

    for (len = maxx; len <= sum; len++)
    {
        if (sum % len)
        {
            continue;
        }
        cnt = sum / len;
        memset(v, 0, sizeof v);
        if (dfs(1, 0, 1))
        {
            break;
        }
    }

    printf("%d", len);

    return 0;
}

SP338 Roads

非常简单的一个图论问题,只需要加一个剪枝即可:

如果花费大于要求或这长度大于已知的,直接结束。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e6 + 5;
const int M = 2e6 + 5;

int h[N], ne[M], e[M], w[M], c[M];
int idx;
int sum;
bool v[M];
int n, m, k;

void add(int a, int b, int C, int d)
{
    e[++idx] = b, ne[idx] = h[a], w[idx] = C, c[idx] = d, h[a] = idx;
}

void dfs(int x, int len, int cost)
{
    if (cost > k || len > sum) //剪枝
    {
        return;
    }
    if (x == n)
    {
        sum = std::min(sum, len);
    }

    for (rint i = h[x]; i; i = ne[i])
    {
        int y = e[i];
        if (v[y])
        {
            continue;
        }
        v[y] = 1;
        dfs(y, len + w[i], cost + c[i]);
        v[y] = 0;
    }
}

int main()
{
    int T;
    scanf("%d", &T);
    while (T--)
    {
        idx = 0;
        sum = 0x3f3f3f3f;
        scanf("%d%d%d", &k, &n, &m);
        memset(v, 0, sizeof v);
        memset(h, 0, sizeof h);
        for (rint i = 1; i <= m; i++)
        {
            int a, b, c, d;
            scanf("%d%d%d%d", &a, &b, &c, &d);
            add(a, b, c, d);
        }
        dfs(1, 0, 0);
        if (sum == 0x3f3f3f3f)
        {
            puts("-1");
        }
        else
        {
            printf("%d\n", sum);
        }
    }

    return 0;
}

Part2.迭代加深

迭代加深就是防止答案在比较浅的层时,搜索到很深而找不到答案的情况,通过设置一个最大迭代深度来解决这个情况

UVA529 加成序列

一道 dfs,迭代加深

我们可以很快的猜出来最终 \(m\) 的长度必然是小于 \(10\) 的。

而这种浅深度的问题正好适用于迭代加深。

之后考虑剪枝

优化搜索顺序 : 我们要让序列中的数字迅速地逼近 \(n\),自然是要 \(i\)\(j\) 从大到小枚举,且 \(j<=i\)

排除等效冗余 : 我们发现,对于不同的 \(i\)\(j\) ,他们的 \(a[i]+a[j]\) 有一定的可能性会相同

最优化剪枝 :后面每一项最多是前一项的 \(2\)

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 1e4 + 5;

int n, path[N];

bool dfs(int u, int k)
// u is the depth of now
// k is the depth of MAX
{
    if (path[u - 1] > n)
    {
        return false;
    }
    if (u == k)
    {
        return path[u - 1] == n;
    }
    if (path[u - 1] * ((long long)1 << (k - u)) < n)//最优化剪枝
    {
        return false;
    }
    bool st[N] = {0};//通过 bool数组排除等效冗余
    for (rint i = u - 1; i >= 0; i--)
    {
        for (rint j = i; j >= 0; j--)//搜索顺序
        {
            int s = path[i] + path[j];
            if (s > n || s <= path[u - 1] || st[s])
            {
                continue;
            }
            st[s] = true;
            path[u] = s;
            if (dfs(u + 1, k))
            {
                return true;
            }
        }
    }
    return false;
}

signed main()
{
    path[0] = 1;
    while (scanf("%d", &n) and n != 0)
    {
        int depth = 1;
        while (!dfs(1, depth))// 不断扩大范围
        {
            depth++;
        }
        for (rint i = 0; i < depth - 1; i++)
        {
            std::cout << path[i] << " ";
        }
        std::cout << path[depth - 1];
	//UVA 输出不能有多余空格
        puts("");
    }

    return 0;
}

Part3.折半搜索

一般想到折半搜索的时候,都是朴素 dfs 时空复杂度为 \(O(k^{q})\) , \(k\) 较小, \(q\) 较大的时候使用。

AcWing.171 送礼物

先把前边都搜出来,再把后边的也搜出来了,之后二分查一下就行了。这个也是大部分折半搜索的常规套路。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'
#define int long long

const int N = 1 << 24;

int n, m, k;
int a[N];
int w[N];
int cnt, ans;

bool cmp(int a, int b)
{
    return a > b;
}

void dfs1(int x, int sum)
{
    if (x > k)
    {
	w[++cnt] = sum;
	return;
    }
    if (a[x] <= m - sum)
    {
	dfs1(x + 1, sum + a[x]);
    }
    dfs1(x + 1, sum);
}

void dfs2(int x, int sum)
{
    if (x > n)
    {
	int l = std::upper_bound(w + 1, w + cnt + 1, m - sum) - (w + 1);
	if (w[l] <= m - sum)
	{
	    ans = std::max(ans, w[l] + sum);
	}
	return;
    }
    if (a[x] <= m - sum)
    {
	dfs2(x + 1, sum + a[x]);
    }
    dfs2(x + 1, sum);
}

signed main()
{
    scanf("%lld%lld", &m, &n);

    for (rint i = 1; i <= n; i++)
    {
	scanf("%lld", &a[i]);
    }

    std::sort(a + 1, a + n + 1, cmp);
    k = n / 2;
    dfs1(1, 0);

    std::sort(w + 1, w + cnt + 1);
    cnt = std::unique(w + 1, w + cnt + 1) - w - 1;

    dfs2(k + 1, 0);

    printf("%lld\n", ans);
    return 0;
}

[CEOI2015 Day2] 世界冰球锦标赛

#include <iostream>
#include <cstdio>
#include <algorithm>

#define rint register int
#define endl '\n'
#define int long long

const int N = 2e6 + 5;

int n, m, k;
int a[N], b[N];
int w[N];
int ans;
int cnt1, cnt2;

void dfs1(int x, int sum)
{
    if (sum > m)
    {
	return;
    }
    if (x > k)
    {
	a[++cnt1] = sum;
	return;
    }
    dfs1(x + 1, sum + w[x]);
    dfs1(x + 1, sum);
}

void dfs2(int x, int sum)
{
    if (sum > m)
    {
	return;
    }
    if (x > n)
    {
        b[++cnt2] = sum;
	return;
    }
    dfs2(x + 1, sum + w[x]);
    dfs2(x + 1, sum);
}

signed main()
{
    scanf("%lld%lld", &n, &m);
    for (rint i = 1; i <= n; i++)
    {
	scanf("%lld", &w[i]);
    }
    k = n / 2;
    dfs1(1, 0);
    dfs2(k + 1, 0);
    std::sort(b + 1, b + cnt2 + 1);
    for (rint i = 1; i <= cnt1; i++)
    {
    	ans += std::upper_bound(b + 1, b + cnt2 + 1, m - a[i]) - (b + 1);
    }
    printf("%lld", ans);
    return 0;
}

[USACO]Balanced Cow Subsets

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <map>
#include <vector>

#define rint register int
#define endl '\n'

const int N = 21;
const int M = 1 << 20;

int n, k;
int a[N], ans;
std::map<int, std::vector<int>> s;
bool v[M];

void dfs1(int x, int sum, int u)
// u 是状态
{
    if (x > k)
    {
	s[sum].push_back(u);
	return;
    }
    dfs1(x + 1, sum, u);
    dfs1(x + 1, sum + a[x], u | (1 << (x - 1)));
    dfs1(x + 1, sum - a[x], u | (1 << (x - 1)));
}

void dfs2(int x, int sum, int u)
{
    if (x > n)
    {
	if (!s.count(sum))
	{
	    return;
	}
	int len = s[sum].size();
	for (rint i = 0; i < len; i++)
	{
	    v[u | s[sum][i]] = 1;
	}
	return;
    }
    dfs2(x + 1, sum, u);
    dfs2(x + 1, sum + a[x], u | (1 << (x - 1)));
    dfs2(x + 1, sum - a[x], u | (1 << (x - 1)));
}

signed main()
{
    scanf("%d", &n);
    for (rint i = 1; i <= n; i++)
    {
    	scanf("%d", &a[i]);
    }
    k = n / 2;
    dfs1(1, 0, 0);
    dfs2(k + 1, 0, 0);
    for (rint i = 1; i < 1 << n; i++)
    {
    	if (v[i])
    	{
    	    ans++;
	}
    }
    printf("%d", ans);
    return 0;
}

Part4.IDA*

AcWing.180 排书

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

using namespace std;

const int N = 20;

int n, a[N];
int w[5][N];
//回复现场使用,前一位表示 depth,保证各层互不影响

int evaluate()//估价函数
{
    int tot = 0;
    for (rint i = 1; i < n; i++)
    {
        if (a[i] + 1 != a[i + 1])
        {
            tot++;
        }
    }
    return (tot + 2) / 3;
    /*
    这个怎么理解呢?
    首先,判断一共有多少个"不连接"的断点,
    如 1 3 4 5, 3 就是断电
    然后一次操作最好的情况是可以回复三个断点
    所以个数除以 3 就是估价函数的值了
    向上取整
    */
}

bool check()//检查序列是否已经有序
{
    for (rint i = 1; i < n; i++)
    {
        if (a[i] + 1 != a[i + 1])
        {
            return 0;
        }
    }
    return 1;
}

bool dfs(int depth, int max_depth)
//depth 当前深度
//max_depth 最大深度
{
    if (depth + evaluate() > max_depth)
    {
        return 0;
    }
    if (check())
    {
        return 1;
    }

    for (rint len = 1; len <= n; len++)
    {
        for (rint l = 1; l <= n - len + 1; l++)
        {
            int r = l + len - 1;
            for (rint k = r + 1; k <= n; k++)
            {
                int y = l;
                memcpy(w[depth + 1], a, sizeof a);

                //下边是移动序列部分
                for (rint x = r + 1; x <= k; x++, y++)
                {
                    a[y] = w[depth + 1][x];
                }
                for (rint x = l; x <= r; x++, y++)
                {
                    a[y] = w[depth + 1][x];
                }

                //递归
                if (dfs(depth + 1, max_depth))
                {
                    return 1;
                }

                //恢复现场
                memcpy(a, w[depth + 1], sizeof a);
            }
        }
    }
    return false;
}

int main()
{
    int T;
    scanf("%d", &T);
    while (T--)
    {
        scanf("%d", &n);
        for (rint i = 1; i <= n; i++)
        {
            scanf("%d", &a[i]);
        }

        int depth = 0;
        while (depth < 5 && !dfs(0, depth))
            depth++;

        if (depth >= 5)
            puts("5 or more");
        else
            printf("%d\n", depth);
    }

    return 0;
}

[SCOI2005] 骑士精神

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define rint register int
#define endl '\n'

const int N = 10;

const int b[10][10] = {
    {0, 0, 0, 0, 0, 0},
    {0, 1, 1, 1, 1, 1},
    {0, 0, 1, 1, 1, 1},
    {0, 0, 0, 2, 1, 1},
    {0, 0, 0, 0, 0, 1},
    {0, 0, 0, 0, 0, 0},
}; //最理想的样子

const int dx[] = {1, -1, -1, 1, 2, -2, -2, 2};
const int dy[] = {2, -2, 2, -2, 1, -1, 1, -1};

int a[N][N];
bool flag;

int evaluate()
{
    int tot = 0;
    for (rint i = 1; i <= 5; i++)
    {
        for (rint j = 1; j <= 5; j++)
        {
            if (a[i][j] != b[i][j])
            //判断有几个棋子不对 
            {
                tot++;
            }
        }
    }
    return tot;
}

bool check(int x, int y)
{
    if (x < 1 || x > 5 || y < 1 || y > 5)
    {
        return 0;
    }
    return 1;
}

void dfs(int now, int sum, int x, int y)
/*
now 表示当前步数
sum 表示总步数
x, y 表示空格坐标
*/
{
    if (now == sum && !evaluate())
    {
        flag = 1;
        return ;
    }
    for (rint i = 0; i < 8; i++)
    {
        int xx = x + dx[i];
        int yy = y + dy[i];//计算当前坐标 
        if (!check(xx, yy))
        {
            continue;
        }
        std::swap(a[xx][yy], a[x][y]);
        if (evaluate() + now - 1 < sum)
        //最优性剪枝:当前的步数 + 差异 > 限制步数 
        {
            dfs(now + 1, sum, xx, yy);
        }
        std::swap(a[xx][yy], a[x][y]);
    }
}

int main()
{
    int T;
    scanf("%d", &T);
    while (T--)
    {
        int s = 0, t = 0;
        flag = 0;
        memset(a, 0, sizeof a);

        for (rint i = 1; i <= 5; i++)
        {
            for (rint j = 1; j <= 5; j++)
            {
                char ch;
                std::cin >> ch;
                if (ch == '*')
                {
                    a[i][j] = 2;
                    s = i;
                    t = j;
                }
                else
                {
                    a[i][j] = ch - 48;
                }
            }
        }

        if (!evaluate())
        {
            puts("0");
            continue;
        }  

        for (rint i = 1; i <= 15; i++)
        {
            dfs(0, i, s, t);
            if (flag == 1)
            {
                printf("%d\n", i);
                break;
            }
        }

        if (!flag)
        {
            puts("-1");
        }
    }

    return 0;
}
posted @ 2022-09-17 16:22  PassName  阅读(189)  评论(0编辑  收藏  举报