[Luogu] P5512 棋盘问题(2)【加强】
Description
在\(N \times N\)的棋盘上\((1≤N≤10)\),填入\(1,2,…,N^2\)共\(N^2\)个数,使得任意两个相邻的数之和为素数。约定:左上角的格子里必须填数字\(1\)。如有多种解,则输出第一行、第一列之和为最小的排列方案;若无解,则输出 NO
。
Solution
搜索好题。
纯爆搜是肯定过不去的,考虑如何优化。
可以先把\(2N^2\)范围内的质数算出来,然后预处理出每个数可能相邻的数有什么。因为要第一行、第一列之和为最小,可以先把第一行、第一列尽量搜小的搜出来,如果当前和比最小的可行解大就剪掉,然后再搜剩下的。但这样还是不够。
考虑第一行、第一列之和的理论下界。如果这\(2N-1\)个位置分别填\(1\sim{2N-1}\),则理论最小为\(N(2N-1)\)。但其实这只是对\(N\)为奇数的,偶数的下界其实是\(N(2N-1)+1\)。这是因为两个相邻的位置只能填一奇一偶,然后\(1\)已经定死了,所以每个位置填的数的奇偶性已经确定了。当\(N\)为偶数时,下界如果要达到\(N(2N-1)\),则必须依次填入\(1\sim{2N-1}\),有\(N\)个奇数,\(N-1\)个偶数。但因为奇偶性限定,第一行、第一列必须填入\(N\)个偶数,\(N-1\)个奇数,矛盾。所以偶数的理论下界至少是\(N(2N-1)+1\)。
知道下界有什么用呢?考虑当\(N\)比较大的时候,有非常多种填数方案,我们总能找到一种,满足第一行、第一列之和达到理论下界。所以只要我们能找到一种方案达到下界,它就一定是最优的,直接输出即可。否则如果\(N\)比较小,找不到方案达到下界,就直接爆搜即可。可以发现,当\(N\ge6\)时,是可以找到的。
这样一通剪完枝后,其实已经跑的非常快了,不算打表的可以排到\(Rank\ 1\)。
Code
#include <bits/stdc++.h>
#include <unistd.h>
using namespace std;
vector < int > leg[1005][1005];
int n, tot, t, fl, nw, lim, mn_f = 1e9, mn_s = 1e9, vis[20005], pr[5005], bk[20005], hd[100005], to[200005], nxt[200005], used[10005], cnt[1005][1005], out[1005][1005];
int read()
{
int x = 0, fl = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') fl = -1; ch = getchar();}
while (ch >= '0' && ch <= '9') {x = (x << 1) + (x << 3) + ch - '0'; ch = getchar();}
return x * fl;
}
void add(int x, int y)
{
tot ++ ;
to[tot] = y;
nxt[tot] = hd[x];
hd[x] = tot;
return;
}
void dfs3(int x, int y)
{
if (y == n + 1)
{
x ++ ;
y = 2;
}
if (x == n + 1)
{
fl = 1;
mn_s = nw;
if (nw == lim)
{
for (int i = 1; i <= n; i ++ )
for (int j = 1; j <= n; j ++ )
printf("%d%c", cnt[i][j], j == n ? '\n' : ' ');
exit(0);
}
for (int i = 1; i <= n; i ++ )
for (int j = 1; j <= n; j ++ )
out[i][j] = cnt[i][j];
return;
}
for (int p = 0; p < (int)leg[cnt[x - 1][y]][cnt[x][y - 1]].size(); p ++ )
{
int d = leg[cnt[x - 1][y]][cnt[x][y - 1]][p];
if (used[d]) continue;
used[d] = 1;
cnt[x][y] = d;
dfs3(x, y + 1);
used[d] = 0;
cnt[x][y] = 0;
}
return;
}
void dfs2(int x, int sum)
{
if (sum > mn_f || sum > mn_s || (sum > lim && n >= 6)) return;
if (x == n + 1)
{
fl = 0;
nw = sum;
dfs3(2, 2);
if (fl) mn_f = sum;
return;
}
for (int i = hd[cnt[x - 1][1]]; i; i = nxt[i])
{
int y = to[i];
if (used[y]) continue;
used[y] = 1;
cnt[x][1] = y;
dfs2(x + 1, sum + y);
used[y] = 0;
}
return;
}
void dfs1(int x, int sum)
{
if (x == n + 1)
{
dfs2(2, sum);
return;
}
for (int i = hd[cnt[1][x - 1]]; i; i = nxt[i])
{
int y = to[i];
if (used[y]) continue;
used[y] = 1;
cnt[1][x] = y;
dfs1(x + 1, sum + y);
used[y] = 0;
}
return;
}
int main()
{
n = read();
if (n % 2) lim = n * (2 * n - 1);
else lim = n * (2 * n - 1) + 1;
if (n == 1 || n == 3)
{
puts("NO");
return 0;
}
for (int i = 2; i <= 2 * n * n; i ++ )
{
if (!vis[i])
{
vis[i] = i;
pr[ ++ t] = i;
bk[pr[t]] = 1;
}
for (int j = 1; j <= t; j ++ )
{
if (pr[j] > vis[i] || i * pr[j] > 2 * n * n) break;
vis[i * pr[j]] = pr[j];
}
}
for (int i = 1; i <= n * n; i ++ )
for (int j = t; j >= 1; j -- )
if (pr[j] > i && (i < pr[j] - i))
add(i, pr[j] - i), add(pr[j] - i, i);
for (int i = 1; i <= n * n; i ++ )
for (int j = 1; j <= n * n; j ++ )
for (int k = 1; k <= n * n; k ++ )
if (bk[i + k] && bk[j + k])
leg[i][j].push_back(k);
cnt[1][1] = out[1][1] = 1; used[1] = 1;
dfs1(2, 1);
for (int i = 1; i <= n; i ++ )
for(int j = 1; j <= n; j ++ )
printf("%d%c", out[i][j], j == n ? '\n' : ' ');
return 0;
}