【题解】BZOJ4278 - [ONTAK2015] Tasowanie
题目大意
给定两个数字串 \(A, B\),通过将 \(A\) 和 \(B\) 进行二路归并得到一个新的数字串 \(T\),请找到字典序最小的 \(T\)。二路归并可大致理解为有两指针初始时分别在 \(A, B\) 开头向后扫描,每次决定在数字串 \(T\) 中加入 \(a_i\) 还是 \(b_j\),并令相应指针后移。
\(1 \leq n, m \leq 200000, 1 \leq A_i, B_i \leq 1000\)
解题思路
如果按照题意模拟,不难想到 贪心 的思路:每次判断 \(a_i\) 和 \(b_j\) 的字典序关系,将小的一个加入数字串 \(T\)。但是当它们字典序相等时,我们需要讨论加入的字符。不难发现如果其中一个字符到该字符串末尾组成的子串如果比另一方更小,那么我们先加入这个字符一定是更优的。具体原因不难理解:字典序是从前向后比较的,所以我们考虑把字典序较小的元素尽快加入数字串 \(T\)。具体证明见 这篇博客(如有侵权请联系 我)。
显然我们需要判断 \(A, B\) 的后缀之间的字典序关系。如果这种比较仅在一个字符串之间,那么可以直接使用 后缀数组。因此我们可以考虑 构造 出一个新的字符串,使得可以通过求出该字符串的后缀数组来解决问题。我们可以构造一个新的字符串 \(S = A + \infty + B + \infty\),然后求出字符串 \(S\) 的后缀数组。我们维护双指针,分别指向 \(S\) 的第一个字符和 \(B\) 的首字符在 \(S\) 中对应的位置,记为 \(l, r\)。每次比较 \(rk_l\) 和 \(rk_r\),将较小的一方加入数字串 \(T\) 并移动指针即可。
算法的正确性显然。对于字符串 \(S\) 的后缀 \(l\) 和后缀 \(r\),如果后缀 \(l\) 包含后缀 \(r\),那么显然此时取任意一个字符都满足贪心策略;如果后缀 \(l\) 不包含后缀 \(r\),那么它们的字典序关系一定不同,取较小值即可。实际上后缀 \(l\) 不在字符串 \(A\) 中出现的部分可以忽略不计,因为当包含时这一部分没有贡献,当不包含时我们一定已经在此前的部分判断完了字典序。所以贪心策略仍然成立。
参考代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 4e5 + 5;
const int maxm = 1e6 + 5;
int n, m, k, len;
int s[maxn], cnt[maxn], id[maxn];
int sa[maxn], rk[maxm], oldrk[maxm];
int read()
{
int res = 0, flag = 1;
char ch = getchar();
while (ch < '0' || ch > '9')
{
if (ch == '-')
flag = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9')
{
res = res * 10 + ch - '0';
ch = getchar();
}
return res * flag;
}
void write(int x)
{
if (x < 0)
{
x = -x;
putchar('-');
}
if (x > 9)
write(x / 10);
putchar(x % 10 + '0');
}
bool cmp(int x, int y, int w)
{
return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w];
}
int main()
{
int p = 0;
n = read();
for (int i = 1; i <= n; i++)
s[i] = read();
s[n + 1] = 1001;
scanf("%d", &k);
for (int i = 1; i <= k; i++)
s[n + i + 1] = read();
s[n + k + 2] = 1001;
len = n + k + 2, m = max(n, 1001);
for (int i = 1; i <= len; i++)
cnt[rk[i] = s[i]]++;
for (int i = 1; i <= m; i++)
cnt[i] += cnt[i - 1];
for (int i = len; i >= 1; i--)
sa[cnt[rk[i]]--] = i;
for (int j = 1; ; j <<= 1, m = p, p = 0)
{
for (int i = len; i > len - j; i--)
id[++p] = i;
for (int i = 1; i <= len; i++)
if (sa[i] > j)
id[++p] = sa[i] - j;
memset(cnt, 0, sizeof(cnt));
for (int i = 1; i <= len; i++)
cnt[rk[id[i]]]++;
for (int i = 1; i <= m; i++)
cnt[i] += cnt[i - 1];
for (int i = len; i >= 1; i--)
sa[cnt[rk[id[i]]]--] = id[i];
memcpy(oldrk, rk, sizeof(rk)), p = 0;
for (int i = 1; i <= len; i++)
rk[sa[i]] = cmp(sa[i], sa[i - 1], j) ? p : ++p;
if (p == len)
{
for (int i = 1; i <= len; i++)
sa[rk[i]] = i;
break;
}
}
int l = 1, r = 1;
while (l <= n || r <= k)
{
if (rk[l] < rk[n + 1 + r])
write(s[l++]), putchar(' ');
else
{
write(s[n + r + 1]);
putchar(' '), r++;
}
}
return 0;
}