解数独算法的实现——剪枝优化
最近人工智能做个小实验,组队选了个数独游戏,顺便研究了一下。解数独感觉主流思想也就是深搜回溯了吧,优化就是各种剪枝方法。
1 引言
数独起源于18世纪初瑞士数学家欧拉等人研究的拉丁方阵(Latin Square),曾风靡日本和英国。现有解法包括基础解法:摒除法,余数法,进阶解法:区块摒除法(Locked Candidates)、数组法(Subset)、四角对角线(X-Wing)、唯一矩形(Unique Rectangle)、全双值坟墓(Bivalue Universal Grave)、单数链(X-Chain)、异数链(XY-Chain)及其他数链的高级技巧等等。已发展出来的方法有近百种之多。本解法中使用了余数法和数组法。
2 算法原理
用个9*9的vector保存整个游戏,注意的是这里的9*9不是9行9列,而是9个九宫格,起初考虑时是希望能更简洁,不过最终都会用到位置所在的行数,列数和所在九宫格,似乎存行列只需要通过行列求9宫格更方便,而存9*九宫格要通过第几个九宫格和九宫格中第几个位置来求行列,更复杂(—_—)! 算是个坑,以后有同学在做的话,要注意了哦!
算法的流程也很简单,上个流程图,
首先,数据结构是9*9的游戏盘,然后为每个位置的可选数字维护一个集合(set),每次更新数字时会同时更新相关位置的集合,稍后会讲到。还有个3*9的集合,是每行每列每个九宫格的可选数字的集合。
然后说说剪枝算法,前面说到用了余数法和数组法,要声明的一点是,这些解法都是人类的解法,即为人类如何选取数字而达到最快解出题目,不过我们计算机深搜是直接选取可选数字最少的位置,所以可能叫法有争议,在此就不深入讨论了。
(1)余数法
每次选择数字后,都要删除对应位置的可选数,及同行同列同九宫格中不能再次选择相同的数字了,所以更新下面图片中灰色位置的集合,如果有位置的可选数字为空,则回溯。
(2)数组法
或许有的时候更新后所有位置都有可选数字,但此时已经出现冲突,例如,一行中只有2个空位,而他们的可选数字相同,这时一定无解了,所以需要回溯。
为了检测,我们需要将更新过的位置所处的行列九宫格的集合全部更新,(灰色位置包括的行列九宫格,实际就是所有行列和5个九宫格),更新方法就是对该集合包含的所有位置的可选数+已选数求并集,得出集合大小小于9则产生了冲突,回溯。
3 代码
1 #include <iostream> 2 #include <vector> 3 #include <set> 4 #include <cstdlib> 5 #include <ctime> 6 #include <fstream> 7 8 using namespace std; 9 10 11 class Sudoku{ 12 public: 13 vector<vector<int> > numMap;//9*9 0 14 vector<vector<set<int> > > availableNum;//row,col,nine 1~9 15 vector<vector<set<int> > > everyNum; 16 17 Sudoku(vector<vector<int> > v){ 18 initMap(v); 19 } 20 21 bool updateMap(int small,int big,int value){ //9*9 vector everyNum[big][small] = value 22 if(numMap[big][small]==0){ //更新游戏盘,删除相关位置中的value 23 set<int> &row = availableNum[0][getRow(small,big)]; 24 set<int> &col = availableNum[1][getCol(small,big)]; 25 set<int> &nin = availableNum[2][big]; 26 set<int>::iterator rowIt = row.find(value); 27 set<int>::iterator colIt = col.find(value); 28 set<int>::iterator ninIt = nin.find(value); 29 if(rowIt!=row.end()&& 30 colIt!=col.end()&& 31 ninIt!=nin.end()){ 32 row.erase(rowIt); 33 col.erase(colIt); 34 nin.erase(ninIt); 35 numMap[big][small] = value; 36 set<int> s; 37 everyNum[big][small] =s; //选中后集合中只有一个选中的数字本身 38 everyNum[big][small].insert(value); 39 if(updateEve(small,big,value)) 40 return true; 41 } 42 } 43 return false; 44 } 45 bool updateEve(int small,int big,int value){ 46 for(int i=0;i!=9;++i){ 47 if(numMap[big][i]==0){ 48 set<int>::iterator it = everyNum[big][i].find(value); 49 if(it!=everyNum[big][i].end()){ 50 everyNum[big][i].erase(it); 51 } 52 if(everyNum[big][i].size()==0) 53 return false; 54 } 55 } 56 int r = getRow(small,big); 57 for(int j=0;j!=3;++j){ 58 for(int k=0;k!=3;++k){ 59 int a=r/3*3+j; 60 int b=r%3*3+k; 61 if(numMap[a][b]==0){ 62 set<int>::iterator it = everyNum[a][b].find(value); 63 if(it!=everyNum[a][b].end()){ 64 everyNum[a][b].erase(it); 65 } 66 if(everyNum[a][b].size()==0) 67 return false; 68 69 } 70 } 71 for(int j=0;j!=3;++j){ 72 for(int k=0;k!=3;++k){ 73 int a=big%3+j*3; 74 int b=small%3+k*3; 75 if(numMap[a][b]==0){ 76 set<int>::iterator it = everyNum[a][b].find(value); 77 if(it!=everyNum[a][b].end()){ 78 everyNum[a][b].erase(it); 79 } 80 if(everyNum[a][b].size()==0) 81 return false; 82 83 } 84 } 85 return true; 86 } 87 bool check(){ //数组法检查 88 for(int i=0;i!=9;++i){ 89 set<int> s; 90 for(int j=0;j!=9;++j){ 91 for(auto it:everyNum[i][j]){ 92 s.insert(it); 93 } 94 } 95 if(s.size()!=9) 96 return false; 97 } 98 for(int r=0;r!=9;++r){ 99 set<int> s; 100 for(int j=0;j!=3;++j){ 101 for(int k=0;k!=3;++k){ 102 int a=r/3*3+j; 103 int b=r%3*3+k; 104 for(auto it:everyNum[a][b]){ 105 s.insert(it); 106 } 107 } 108 } 109 if(s.size()!=9) 110 return false; 111 } 112 for(int c=0;c!=9;++c){ 113 set<int> s; 114 for(int j=0;j!=3;++j){ 115 for(int k=0;k!=3;++k){ 116 int a=j*3+c/3; 117 int b=k*3+c%3; 118 for(auto it:everyNum[a][b]){ 119 s.insert(it); 120 } 121 } 122 } 123 if(s.size()!=9) 124 return false; 125 } 126 return true; 127 } 128 void initMap(vector<vector<int> > vv){ 129 vector<int> v(9,0); 130 set<int> s; 131 for(int i=1;i!=10;++i){ 132 s.insert(i); 133 } 134 vector<set<int> > sv(9,s); 135 availableNum = vector<vector<set<int> > >(3,sv); 136 everyNum = vector<vector<set<int> > >(9,sv); 137 numMap = vector<vector<int> >(9,v); 138 for(int i=0;i!=9;++i){ 139 vector<int> tmp = vv[i]; 140 for(int j=0;j!=3;++j){ 141 for(int k=0;k!=3;++k){ 142 numMap[i/3*3+j][i%3*3+k]=tmp[j*3+k]; 143 if(tmp[j*3+k]!=0){ 144 int value = tmp[j*3+k]; 145 set<int> a; 146 everyNum[i/3*3+j][i%3*3+k] = a; 147 set<int> &row = availableNum[0][getRow(i%3*3+k,i/3*3+j)]; 148 set<int> &col = availableNum[1][getCol(i%3*3+k,i/3*3+j)]; 149 set<int> &nin = availableNum[2][i/3*3+j]; 150 set<int>::iterator rowIt = row.find(value); 151 set<int>::iterator colIt = col.find(value); 152 set<int>::iterator ninIt = nin.find(value); 153 row.erase(rowIt); 154 col.erase(colIt); 155 nin.erase(ninIt); 156 } 157 } 158 } 159 } 160 showMap(); 161 set<int> tmp; 162 for(int i=0;i!=9;++i){ 163 for(int j=0;j!=9;++j){ 164 if(numMap[i][j]==0){ 165 everyNum[i][j] = getEve(j,i); 166 } 167 else{ 168 everyNum[i][j] = tmp; 169 everyNum[i][j].insert(numMap[i][j]); 170 } 171 172 } 173 } 174 showEve(); 175 } 176 set<int> getEve(int small,int big){ 177 set<int> &row = availableNum[0][getRow(small,big)]; 178 set<int> &col = availableNum[1][getCol(small,big)]; 179 set<int> &nin = availableNum[2][big]; 180 set<int> res; 181 for(auto it:row){ 182 if(col.find(it)!=col.end()&&nin.find(it)!=nin.end()){ 183 res.insert(it); 184 } 185 } 186 return res; 187 } 188 189 pair<int,int> getMin(){ 190 pair<int,int> res = make_pair(9,9); 191 int Min =10; 192 for(int i=0;i!=9;++i){ 193 for(int j=0;j!=9;++j){ 194 if(numMap[i][j]==0&&everyNum[i][j].size()<Min){ 195 res = make_pair(i,j); 196 Min = everyNum[i][j].size(); 197 } 198 } 199 } 200 return res; 201 } 202 int getEveNum(int small,int big){ 203 set<int> &row = availableNum[0][getRow(small,big)]; 204 set<int> &col = availableNum[1][getCol(small,big)]; 205 set<int> &nin = availableNum[2][big]; 206 int res=0; 207 for(auto it:row){ 208 if(col.find(it)!=col.end()&&nin.find(it)!=nin.end()){ 209 res++; 210 } 211 } 212 return res; 213 } 214 int getRow(int small,int big){ 215 return big/3*3+small/3; 216 } 217 int getCol(int small,int big){ 218 return big%3*3+small%3; 219 } 220 void showMap(){ 221 for(int i=0;i!=9;++i){ 222 cout<<" "; 223 if(i%3==0){ 224 cout<<"---------------------"<<endl; 225 cout<<" "; 226 } 227 for(int j=0;j!=3;++j){ 228 cout<<"|"; 229 for(int k=0;k!=3;++k){ 230 if(numMap[i/3*3+j][i%3*3+k]==0) 231 cout<<" "; 232 else 233 cout<<numMap[i/3*3+j][i%3*3+k]<<" "; 234 } 235 } 236 cout<<"|"<<endl; 237 } 238 cout<<" "; 239 cout<<"---------------------"<<endl; 240 241 } 242 void showEve(){ 243 for(int i=0;i!=9;++i){ 244 cout<<" "; 245 if(i%3==0){ 246 cout<<"---------------------"<<endl; 247 cout<<" "; 248 } 249 for(int j=0;j!=3;++j){ 250 cout<<"|"; 251 for(int k=0;k!=3;++k){ 252 if(numMap[i/3*3+j][i%3*3+k]==0){ 253 for(auto it:everyNum[i/3*3+j][i%3*3+k]) 254 cout<<it; 255 cout<<" "; 256 } 257 else 258 cout<<numMap[i/3*3+j][i%3*3+k]<<" "; 259 } 260 } 261 cout<<"|"<<endl; 262 } 263 cout<<" "; 264 cout<<"---------------------"<<endl; 265 266 } 267 }; 268 269 bool solu(Sudoku mSudoku,int small,int big,int value); 270 bool solu(Sudoku mSudoku); 271 bool solu1(Sudoku mSudoku,int small,int big,int value); 272 bool solu1(Sudoku mSudoku){ 273 pair<int,int> p = mSudoku.getMin(); 274 for(auto i:mSudoku.everyNum[p.first][p.second]){ 275 if(solu1(mSudoku,p.second,p.first,i)) 276 return true; 277 } 278 return false; 279 } 280 bool solu(Sudoku mSudoku){ 281 pair<int,int> p = mSudoku.getMin(); 282 for(auto i:mSudoku.everyNum[p.first][p.second]){ 283 if(solu(mSudoku,p.second,p.first,i)) 284 return true; 285 } 286 return false; 287 } 288 289 bool solu(Sudoku mSudoku,int small,int big,int value){ 290 if(!mSudoku.updateMap(small,big,value)) 291 return false; 292 pair<int,int> p = mSudoku.getMin(); 293 if(p==make_pair(9,9)){ 294 mSudoku.showMap(); 295 return true; 296 } 297 for(auto i:mSudoku.everyNum[p.first][p.second]){ 298 if(solu(mSudoku,p.second,p.first,i)) 299 return true; 300 } 301 return false; 302 } 303 304 bool solu1(Sudoku mSudoku,int small,int big,int value){ 305 if(!mSudoku.updateMap(small,big,value)||!mSudoku.check()) 306 return false; 307 pair<int,int> p = mSudoku.getMin(); 308 if(p==make_pair(9,9)){ 309 mSudoku.showMap(); 310 return true; 311 } 312 for(auto i:mSudoku.everyNum[p.first][p.second]){ 313 if(solu1(mSudoku,p.second,p.first,i)) 314 return true; 315 } 316 return false; 317 } 318
1 int main() 2 { 3 vector<int> a(9,0); 4 vector<vector<int> > b(9,a); 5 vector<vector<vector<int> > > v(95,b); 6 vector<double> tim(95,0); 7 vector<double> tim1(95,0); 8 string str; 9 for(int i=0;i!=95;++i){ 10 cin>>str; 11 for(int j=0;j!=9;++j){ 12 for(int k=0;k!=9;++k){ 13 if(str[j*9+k]=='.') 14 v[i][j][k]=0; 15 else 16 v[i][j][k]=str[j*9+k]-'0'; 17 } 18 } 19 Sudoku mSudoku(v[i]); 20 //unsigned long start = ::GetTickCount(); 21 //cout<<start; 22 clock_t start = clock(); 23 solu(mSudoku); 24 clock_t en = clock(); 25 tim[i] = (double)(en - start) / CLOCKS_PER_SEC; 26 start = clock(); 27 solu1(mSudoku); 28 en = clock(); 29 tim1[i] = (double)(en - start) / CLOCKS_PER_SEC; 30 } 31 ofstream file("spendTime.txt"); 32 for(int i=0;i!=95;++i){ 33 file<<tim[i]<<","; 34 } 35 file<<endl; 36 file.close(); 37 ofstream file1("spendTime1.txt"); 38 for(int i=0;i!=95;++i){ 39 file1<<tim1[i]<<","; 40 } 41 file1.close(); 42 return 0; 43 }
写了个main()从测试从网上下载下来的95个用例,记录个时间,用python画出来。
可以看到对于一些回溯次数较多的用例,剪枝效果还是很不错的。
附:代码及相关文件