Sicily1317-Sudoku-位运算暴搜
最终代码地址:https://github.com/laiy/Datastructure-Algorithm/blob/master/sicily/1317.c
这题博主刷了1天,不是为了做出来,AC之后在那死磕性能...
累积交了45份代码,纪念一下- -
以上展示了从1.25s优化到0.03s的艰苦历程...
来看题目吧,就是一个数独求解的题:
1317. Sudoku
Constraints
Time Limit: 10 secs, Memory Limit: 32 MB
Description
Sudoku is a placement puzzle. The goal is to enter a symbol in each cell of a grid, most frequently a 9 x 9 <tex2html_verbatim_mark>grid made up of 3 x 3 <tex2html_verbatim_mark>subgrids. Each row, column and subgrid must contain only one instance of each symbol. Sudoku initially became popular in Japan in 1986 and attained international popularity in 2005.
The word Sudoku means ``single number" in Japanese. The symbols in Sudoku puzzles are often numerals, but arithmetic relationships between numerals are irrelevant.
According to wikipedia:
The number of valid Sudoku solution grids for the standard 9 x 9 <tex2html_verbatim_mark>grid was calculated by Bertram Felgenhauer in 2005 to be 6,670,903,752,021,072,936,960, which is roughly the number of micrometers to the nearest star. This number is equal to 9! * 722 * 27 * 27, 704, 267, 971 <tex2html_verbatim_mark>, the last factor of which is prime. The result was derived through logic and brute force computation. The number of valid Sudoku solution grids for the 16 x 16 <tex2html_verbatim_mark>derivation is not known.
Write a program to find a solution to a 9 x 9 <tex2html_verbatim_mark>Sudoku puzzle given a starting configuration.
Input
The first line will contain an integer specifying the number of puzzles to be solved. The remaining lines will specify the starting configuration for each of the puzzles. Each line in a starting configuration will have nine characters selected from the numerals 1-9 and the underscore which indicates an empty cell.
Output
For each puzzle, the output should specify the puzzle number (starting at one) and describe the solution characteristics. If there is a single solution, it should be printed. Otherwise, a message indicating whether there are no solutions or multiple solutions should be printed. The output should be similar to that shown below. All input cases have less than 10,000 solutions.
Sample Input
3 ________4 1____9_7_ __37_28__ ____7_26_ 4_______8 _91_6____ __42_36__ _3_14___9 9________ 7_9__2___ 3_____891 ___39___4 48__6____ __5___6__ ____4__23 2___57___ 568_____7 ___8__4_2 82_______ ___5__2__ __6_4_7__ _5___1_7_ 9_2_5_4_1 _3_8_6_9_ __3_6_1__ __5__2___ _______34
Sample Output
Puzzle 1 has 6 solutions Puzzle 2 solution is 719482365 324675891 856391274 482563719 135729648 697148523 243957186 568214937 971836452 Puzzle 3 has no solution
刚看到这题,思路很清楚,启发式DFS,按每个空格可以填入数字的可能数量作为权重,每次取权重最小的搜索空间进行拓展,这样可以在很大程度上保证剪枝剪的是最大的。
比如,现在有5个搜索空间(即5个没有填入数字的空格),每个空间的权重为(1,2,3,4,5)。
那么DFS的时候树伸出去的枝叶我们喜欢尽量的少,这样在递归的时候回溯的时候剪掉的枝叶会更多,而每个职业延伸出去的代价可以认为是一样的。
假设从根到叶延伸(式子中从左到右)的形式为1 * 2 * 3 * 4 * 5,那么在第二个空间搜索的时候如果回溯了是不是剪掉了一半的枝?
而如果反过来搜5 * 4 * 3 * 2 * 1这样来搜的话回溯只是剪掉了1/5的枝。
这个原理和马周游的Warnsdorff's rule其实是一样的。
然后来考虑一下数据结构,为了方便计算,我们用3个数组来记录每个行,列,块数字的占用情况。(row_space, col_space, block_space)
用board来记录当前数独填入状态,在用nodes来记录搜索空间的权重,比如nodes[i][j] = 4意味着第i行第j列权重为4。
代码如下:
1 #include <cstdio> 2 #include <cstring> 3 4 short board[10][10]; 5 bool col_space[10][10], row_space[10][10], block_space[3][3][10]; 6 short record[10][10]; 7 short nodes[10][10]; 8 int solutions, i, j, min, weight, k, m, v, c, update_value; 9 10 inline void update_weight(int &i, int &j) { 11 if (nodes[i][j] == -1 || row_space[i][update_value] || col_space[j][update_value] || \ 12 block_space[(i - 1) / 3][(j - 1) / 3][update_value]) 13 return; 14 nodes[i][j]--; 15 } 16 17 inline void count_weight(int i, int j) { 18 static int record_i, record_j; 19 record_i = i, record_j = j; 20 for (m = 1; m < 10; m++) 21 if (m != record_j) 22 update_weight(i, m); 23 for (m = 1; m < 10; m++) 24 if (m != record_i) 25 update_weight(m, j); 26 i = ((i - 1) / 3) * 3 + 1; 27 j = ((j - 1) / 3) * 3 + 1; 28 for (m = i; m < i + 3; m++) 29 for (v = j; v < j + 3; v++) 30 if (m != record_i && v != record_j) 31 update_weight(m, v); 32 } 33 34 inline void heuristic_dfs() { 35 if (!c) { 36 if (!solutions) 37 memcpy(record, board, sizeof(board)); 38 solutions++; 39 return; 40 } 41 min = 10; 42 short record_i, record_j, k, board_record_i, board_record_j, temp[10][10]; 43 for (i = 1; i <= 9; i++) 44 for (j = 1; j <= 9; j++) 45 if (nodes[i][j] != -1 && nodes[i][j] < min) 46 min = nodes[i][j], record_i = i, record_j = j; 47 board_record_i = (record_i - 1) / 3, board_record_j = (record_j - 1) / 3; 48 c--; 49 nodes[record_i][record_j] = -1; 50 memcpy(temp, nodes, sizeof(nodes)); 51 for (k = 1; k < 10; k++) 52 if (!(row_space[record_i][k] || col_space[record_j][k] || block_space[board_record_i][board_record_j][k])) { 53 update_value = k; 54 count_weight(record_i, record_j); 55 row_space[record_i][k] = col_space[record_j][k] = block_space[board_record_i][board_record_j][k] = true; 56 board[record_i][record_j] = k; 57 heuristic_dfs(); 58 row_space[record_i][k] = col_space[record_j][k] = block_space[board_record_i][board_record_j][k] = false; 59 memcpy(nodes, temp, sizeof(nodes)); 60 } 61 c++; 62 nodes[record_i][record_j] = min; 63 } 64 65 int main() { 66 int t, count = 1; 67 scanf("%d", &t); 68 char input[10]; 69 while (t--) { 70 if (count != 1) 71 printf("\n"); 72 memset(board, -1, sizeof(board)); 73 memset(nodes, -1, sizeof(nodes)); 74 memset(col_space, 0, sizeof(col_space)); 75 memset(row_space, 0, sizeof(row_space)); 76 memset(block_space, 0, sizeof(block_space)); 77 solutions = 0; 78 for (i = 1; i <= 9; i++) { 79 scanf("%s", input); 80 for (j = 0; j < 9; j++) { 81 if (input[j] != '_') 82 board[i][j + 1] = input[j] - '0', row_space[i][board[i][j + 1]] = true, col_space[j + 1][board[i][j + 1]] = true, \ 83 block_space[(i - 1) / 3][(j) / 3][board[i][j + 1]] = true; 84 } 85 } 86 c = 0; 87 for (i = 1; i <= 9; i++) 88 for (j = 1; j <= 9; j++) 89 if (board[i][j] == -1) { 90 weight = 9; 91 for (k = 1; k < 10; k++) 92 if (row_space[i][k] || col_space[j][k] || block_space[(i - 1) / 3][(j - 1) / 3][k]) 93 weight--; 94 nodes[i][j] = weight; 95 c++; 96 } 97 heuristic_dfs(); 98 if (!solutions) 99 printf("Puzzle %d has no solution\n", count++); 100 else if (solutions > 1) 101 printf("Puzzle %d has %d solutions\n", count++, solutions); 102 else { 103 printf("Puzzle %d solution is\n", count++); 104 for (i = 1; i <= 9; i++) { 105 for (j = 1; j <= 9; j++) 106 printf("%d", record[i][j]); 107 printf("\n"); 108 } 109 } 110 } 111 return 0; 112 }
这种思路不断做下去极限是0.16s。
那博主开头的0.03s是怎么做到的呢?
答案是:用matrix67的八皇后那样用位运算提高状态空间计算的效率直接暴搜就可以了。
事实确实如此!(我受到了惊吓,启发式DFS比DFS慢这么多)
我要代码按我的启发式思路走下去在数据结构和执行性能上打了很大折扣,我需要额外的计算来维护启发式的执行,如果启发式带来的优势不能大于算法变复杂带来的劣势的话会得不偿失!
来考虑一下为了执行启发式需要的代价:
1. 初始所有状态的weight,遍历所有结点。
2. 每次dfs选取最小weight结点的代价。
3. 每次延伸枝的时候我们更改了某一个结点的状态,那么对应的行,列,块的结点依次需要判断是否需要更新weight。
而事实证明维护以上状态需要的计算带来的负担是大于启发式带来剪枝优化的收益的。
接下来介绍一下什么是位运算状态暴搜。
其实就是用位来描述状态,语义是由人来赋予的。这里用9位来描述某一行/列/块来描述这行/列/块数字的占用情况。
比如row_space[0] = 111111111,代表第一行9个数字都是可以搜索的空间,如果变成111111110则表示第一行的1这个数字已经被占用。
这么做有什么好处?效率和空间都暴涨!
如何获得当前结点可以搜索的空间?
space = row_space[i] & col_space[j] & block_space[block_index]就行了不是吗?
如何遍历搜索空间?
用space & (-space)来提取出最后一个1代表的搜索空间,执行,然后用异或运算来更新状态空间即可。
这样暴搜性能可以达到0.03s!
代码如下:
1 #include <cstdio> 2 #include <cstring> 3 4 int solutions; 5 int row_space[9], col_space[9], block_space[9]; 6 int board[9][9], record[9][9]; 7 int digit_table[(1 << 8) + 1]; 8 9 void dfs(int i, int j) { 10 while (i >= 0 && board[i][j]) 11 if (j) 12 j--; 13 else 14 i--, j = 8; 15 if (i == -1) { 16 if (!solutions) 17 memcpy(record, board, sizeof(board)); 18 solutions++; 19 return; 20 } 21 int block_index = i / 3 * 3 + j / 3; 22 int space = row_space[i] & col_space[j] & block_space[block_index]; 23 int put; 24 while (space) { 25 put = space & (-space); 26 space ^= put; 27 board[i][j] = digit_table[put]; 28 row_space[i] ^= put; 29 col_space[j] ^= put; 30 block_space[block_index] ^= put; 31 dfs(i, j); 32 row_space[i] ^= put; 33 col_space[j] ^= put; 34 block_space[block_index] ^= put; 35 } 36 board[i][j] = 0; 37 } 38 39 int main() { 40 int t, count = 1, i, j, init, mask; 41 scanf("%d", &t); 42 char input[10]; 43 init = (1 << 9) - 1; 44 for (int i = 0; i < 9; i++) 45 digit_table[1 << i] = i + 1; 46 while (t--) { 47 if (count != 1) 48 printf("\n"); 49 solutions = 0; 50 for (i = 0; i < 9; i++) 51 row_space[i] = col_space[i] = block_space[i] = init; 52 memset(board, 0, sizeof(board)); 53 for (i = 0; i < 9; i++) { 54 scanf("%s", input); 55 for (j = 0; j < 9; j++) 56 if (input[j] != '_') { 57 board[i][j] = input[j] - '0'; 58 mask = ~(1 << (board[i][j] - 1)); 59 row_space[i] &= mask; 60 col_space[j] &= mask; 61 block_space[i / 3 * 3 + j / 3] &= mask; 62 } 63 } 64 dfs(8, 8); 65 if (!solutions) 66 printf("Puzzle %d has no solution\n", count++); 67 else if (solutions > 1) 68 printf("Puzzle %d has %d solutions\n", count++, solutions); 69 else { 70 printf("Puzzle %d solution is\n", count++); 71 for (i = 0; i < 9; i++) { 72 for (j = 0; j < 9; j++) 73 printf("%d", record[i][j]); 74 printf("\n"); 75 } 76 } 77 } 78 return 0; 79 }