AtCoder Grand Contest 060 E Number of Cycles
讲个笑话,一年前做过,今天模拟赛出了,但是完全不记得,然后想了一种完全不同的方法,我真抽象。
首先考虑什么时候有解。显然 \(m = n + f(a)\) 的时候有解,令 \(b_i = i, c_i = a_i\) 即可。然后考虑任意交换一对 \((i, j)\),此时 \((b_i, b_j), (c_i, c_j)\) 都会被交换。所以 \(f(b), f(c)\) 的变化量均为 \(\pm 1\)。所以 \(f(b) + f(c)\) 的奇偶性确定,若 \(m\) 和 \(n + f(a)\) 的奇偶性不同就无解。
然后发现 \(m \le n + f(a)\) 时才可能有解。考虑 \(b_i = i, c_i = a_i\) 的情形,显然交换任意一对后 \(f(b)\) 都会 \(-1\),所以交换后一定不优。
然后发现这两个条件是充要条件,下面考虑构造性证明。
先构造出 \(f(b) + f(c) = n + f(a)\) 的方案,也就是 \(b_i = i, c_i = a_i\)。考虑先从 \(c\) 入手,以任意顺序扫 \(c\) 的连通块,并且与上一个连通块合并,也就是任意交换这个连通块和上一个连通块的一对元素即可。显然它们在 \(b\) 中属于的连通块也不同。所以 \(f(b), f(c)\) 经过这样的一次操作后都会 \(-1\),所以 \(f(b) + f(c)\) 经过这样一次操作后会 \(-2\)。
扫完之后 \(f(c) = 1\),但是 \(f(b)\) 可能 \(> 1\)。我们希望当 \(f(b) > 2\) 时让 \(f(b)\) 减小 \(2\),同时保持 \(f(c) = 1\)。
不妨考虑找到一个 \(i\) 使得 \((i, c_i)\) 在 \(b\) 中不在同一个连通块(容易发现一定能找到,因为不能找到就说明 \(f(b) = 1\),可以直接退出)。此时可以直接交换 \((i, c_i)\),这样 \(f(b)\) 会 \(-1\),但是 \(f(c)\) 会 \(+1\)。
我们发现一个非常好的事情,就是 \(c_i\) 在 \(c\) 中变成了一个自环,同时所有其他的点组成了一个环。也就是说我们让 \(c_i\) 和任意一个其他点交换,\(f(c)\) 都会变回 \(1\)。那么就让 \(c_i\) 和任意一个在 \(b\) 中和它不在同一个连通块的点 \(j\) 交换,这样 \(f(b)\) 会 \(-1\)。
所以经过上面的过程,\(f(b) + f(c)\) 会 \(-2\)。
考虑快速维护这个构造的过程。可以使用一个类似栈的东西维护全部 \((i, c_i)\) 可能在 \(b\) 中在同一个连通块的点 \(i\)。找的时候直接取栈顶,在一个连通块就弹出。交换一对 \((i, j)\) 就把 \(i, j\) 都加入栈即可。
然后对一个点 \(x\) 找在 \(b\) 中和它不在一个连通块的任意一个点,可以维护任意两个连通块的祖先,其中必然有一个不是 \(x\) 的祖先。
所以时间复杂度就是 \(O(n \log n)\),\(\log n\) 是并查集复杂度,可以粗略地看成线性。实际跑得很快。
code
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<int, int> pii;
const int maxn = 200100;
int n, m, a[maxn], b[maxn], stk[maxn * 5], top, X, Y, fa[maxn];
bool vis[maxn];
int find(int x) {
return fa[x] == x ? x : fa[x] = find(fa[x]);
}
inline void merge(int x, int y) {
x = find(x);
y = find(y);
if (x != y) {
fa[x] = y;
vis[x] = 1;
while (vis[X]) {
++X;
}
while (vis[Y] || X == Y) {
++Y;
}
}
}
void solve() {
scanf("%d%d", &n, &m);
X = Y = 0;
for (int i = 1; i <= n; ++i) {
scanf("%d", &a[i]);
b[i] = fa[i] = i;
vis[i] = 0;
}
for (int i = 1; i <= n; ++i) {
merge(i, a[i]);
}
int t = n;
for (int i = 1; i <= n; ++i) {
t += (fa[i] == i);
}
if (m > t || ((m + t) & 1)) {
puts("No");
return;
}
puts("Yes");
if (m == t) {
for (int i = 1; i <= n; ++i) {
printf("%d ", i);
}
putchar('\n');
return;
}
int lst = -1;
for (int i = 1; i <= n; ++i) {
if (fa[i] == i) {
if (lst != -1) {
swap(a[lst], a[i]);
swap(b[lst], b[i]);
t -= 2;
}
lst = i;
}
if (m == t) {
for (int i = 1; i <= n; ++i) {
printf("%d ", b[i]);
}
putchar('\n');
return;
}
}
for (int i = 1; i <= n; ++i) {
fa[i] = i;
vis[i] = 0;
stk[++top] = i;
}
X = 1;
Y = 2;
for (int i = 1; i <= n; ++i) {
merge(i, b[i]);
}
while (1) {
while (1) {
int i = stk[top];
int j = a[i];
if (find(i) == find(j)) {
--top;
continue;
}
merge(i, j);
stk[++top] = j;
swap(a[i], a[j]);
swap(b[i], b[j]);
int k = (find(i) == X ? Y : X);
stk[++top] = k;
merge(j, k);
swap(a[j], a[k]);
swap(b[j], b[k]);
break;
}
t -= 2;
if (m == t) {
for (int i = 1; i <= n; ++i) {
printf("%d ", b[i]);
}
putchar('\n');
return;
}
}
}
int main() {
int T = 1;
scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}