Loading

【题解】LOJ 516 - DP 一般看规律

题目大意

题目链接

给定一个长度为 \(n\) 的序列和 \(m\) 次操作,每次操作可以把序列所有值 \(x\) 改成值 \(y\) 。请在每次操作后输出序列中相同的值之间的最短距离。例如序列 \([1, 2, 1, 4, 2]\) ,其相同的值之间的最短距离为下标为 \(1\)\(3\) 的值之间的距离,也就是 \(3\) 。如果序列中没有相同的值,输出 2147483647

解题思路

这道题属于求解最值的问题。根据经验,我们可以猜想这道题可能是用 \(dp\) ,贪心或者二分等算法求解。可以发现,答案一定单调不增的,下面给出不严谨证明:

因为将值 \(x\) 改成值 \(y\) 以后,原本值 \(x\) 贡献的答案并不会改变;此时值 \(x\) 反而可能与值 \(y\) 产生更小的答案,即使从最坏的情况考虑,答案也一定不会增加。

由此,我们可以想出一个时间复杂度为 \(O(n ^ 2)\) 的乱搞做法:每次操作都扫描一遍数组,将值 \(x\) 改成值 \(y\) 。此时值 \(y\) 中包含了原来的值 \(x\) ,直接求出修改后值 \(y\) 两两间隔的最小值,并与原来的答案比较即可。

遗憾的是,此题的数据并没有 \(n, m \leq 1000\) 的情况。根据笔者珍贵的考场经验,直接对数组进行修改会得到 \(0\) 分的好成绩!因此,我们考虑其他做法。显然 \(dp\) 和贪心不可行,我们考虑使用技巧来优化修改值 \(x\) 和求解答案的过程。

每次修改以后查询答案,我们只需要查询 \(y\) 出现过的所有下标。因此,我们可以将所有值为 \(y\) 的下标都存在一个集合中。将值 \(x\) 修改成值 \(y\) ,相当于把 \(x\) 的下标集合合并到 \(y\) 的下标集合。发散思维,得到本题的正解做法:启发式合并 。估算时间复杂度,单次修改 \(O(logn)\) ,显然可以跑进 \(1\) 秒。

每次合并后必须要在 \(O(logn)\) 的时间复杂度内查找出集合 \(y\) 中与集合 \(x\) 中的某个数最接近的值。因此,我们选择用 set 数组存储,直接二分查找即可。因为值域太大,我们还需要进行 离散化 处理。set[i] 表示离散化后排名为 i 的数在数组中出现的所有位置。

id[i] 表示排名为 \(i\) 的数对应的集合下标。每次启发式合并两个不同的集合 \(x, y\) 时,不妨设 \(x\) 的元素个数较少,否则直接交换 id[x]id[y] 。此时我们需要更新答案:可能与当前的下标 \(v\) 更新答案的下标只有最接近 \(v\) 的前后两个下标。我们考虑使用 lower_bound ,因为下标不会重复,所以我们可以直接 lower_bound 求出它后面最接近它的下标。lower_bound 的前一个元素就是 \(v\) 前面最接近 \(v\) 的下标,分别更新答案即可。更新完答案,记得将集合 \(x\) 加入集合 \(y\) ,并清空集合 \(x\) 中的元素。

最后,对 set 进行 lower_bound 一定要使用 s.lower_bound(val) 而非 lower_bound(s.begin(), s.end(), val) 。否则会得到光荣的 \(TLE\ 48\) 分。

参考代码

#include <cstdio>
#include <set>
#include <algorithm>
using namespace std;
 
const int maxn = 3e5 + 5;
const int inf = 2147483647;
 
int n, m, ans;
int cnt, tot;
int a[maxn], b[maxn], idx[maxn];
int x[maxn], y[maxn], id[maxn];
set<int> s[maxn];
 
int read()
{
    int res = 0, flag = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9')
    {
        if (ch == '-')
            flag = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        res = res * 10 + ch - '0';
        ch = getchar();
    }
    return res * flag;
}
 
void write(int x)
{
    if (x < 0)
    {
        putchar('-');
        x = -x;
    }
    if (x > 9)
        write(x / 10);
    putchar(x % 10 + '0');
}
 
void merge(int x, int y)
{
    if (x == y)
        return;
    if (s[id[x]].size() > s[id[y]].size())
        swap(id[x], id[y]);
    set<int>::iterator i, j;
    for (i = s[id[x]].begin(); i != s[id[x]].end(); i++)
    {
        j = s[id[y]].lower_bound(*i);
        if (j != s[id[y]].end())
            ans = min(ans, *j - *i);
        if (j != s[id[y]].begin())
        {
            j--;
            ans = min(ans, *i - *j);
        }
        s[id[y]].insert(*i);
    }
    s[id[x]].clear();
}
 
int main()
{
    ans = inf;
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++)
        b[++cnt] = a[i] = read();
    for (int i = 1; i <= m; i++)
    {
        b[++cnt] = x[i] = read();
        b[++cnt] = y[i] = read();
    }
    sort(b + 1, b + cnt + 1);
    tot = unique(b + 1, b + cnt + 1) - b - 1;
    for (int i = 1; i <= cnt; i++)
        id[i] = i;
    for (int i = 1; i <= n; i++)
    {
        a[i] = lower_bound(b + 1, b + tot + 1, a[i]) - b;
        s[a[i]].insert(i);
        if (idx[a[i]])
            ans = min(ans, i - idx[a[i]]);
        idx[a[i]] = i;
    }
    for (int i = 1; i <= m; i++)
    {
        x[i] = lower_bound(b + 1, b + tot + 1, x[i]) - b;
        y[i] = lower_bound(b + 1, b + tot + 1, y[i]) - b;
        merge(x[i], y[i]);
        write(ans), puts("");
    }
    return 0;
}
posted @ 2021-07-24 23:32  kymru  阅读(77)  评论(0编辑  收藏  举报