Sudoku POJ - 3074

原题链接

考察:dfs+剪枝

错误思路:

       一开始是想按行搜或者九宫格搜.但是只能想到每九个格子搜一次看有哪些数字,再dfs

没想到位运算优化.

正确思路:

       预处理每行、每列、每个九宫格有哪些数字.用8位二进制数表示哪些数字被使用了(本每次找九格用位运算优化到O(1)).每次搜索找到分支最少的格子搜(两重for,这个剪枝优化很强,不用担心TLE).

       这里还涉及的操作是快速求二进制中1的个数,可以预处理,还有求1再哪些位置,也可以配合lowbit预处理.

  1 #include <iostream>
  2 #include <cstring>
  3 #include <cstdio>
  4 using namespace std;
  5 const int N = 10,M = 1<<9;
  6 char mp[N][N],s[N*9];
  7 int pre[M],tot,sum[M];
  8 int row[N],col[N],cell[N/3][N/3],all;
  9 int lowbit(int x)
 10 {
 11     return x&-x;
 12 }
 13 void change()
 14 {
 15     int len = strlen(s+1);
 16     for(int i=0;i<9;i++) row[i] = all,col[i] = all;
 17     for(int i=0;i<3;i++)
 18       for(int j=0;j<3;j++) cell[i][j] = all;
 19     for(int i=1,j=0;i<=len;i++)
 20     {
 21         mp[j][(i-1)%9] = s[i];
 22         if(s[i]=='.') tot++;
 23         if(i%9==0) j++;
 24     }
 25     for(int i=0;i<9;i++)
 26       for(int j=0;j<9;j++)
 27       {
 28           if(mp[i][j]!='.')
 29           {
 30               int x = mp[i][j]-'1';
 31               row[i]-=1<<x;
 32               cell[i/3][j/3] -= 1<<x;//为0表示该数字不可用.
 33           }
 34           if(mp[j][i]!='.')
 35           {
 36               int x = mp[j][i]-'1';
 37               col[i]-=1<<x;
 38           }
 39       }
 40 }
 41 int get(int x,int y)
 42 {
 43     return col[y]&row[x]&cell[x/3][y/3];
 44 }
 45 void draw(int i,int j,int x,bool isre)
 46 {
 47     if(isre)
 48     {
 49         row[i]-=(1<<x);
 50         col[j]-=(1<<x);
 51         cell[i/3][j/3]-=(1<<x);
 52         mp[i][j] = x+'1';
 53     }else{
 54         row[i]|=(1<<x);
 55         col[j]|=(1<<x);
 56         cell[i/3][j/3]|=(1<<x);
 57         mp[i][j]='.';
 58     }
 59 }
 60 bool dfs(int now)
 61 {
 62     if(!now) return 1;
 63     //找枝条最少的点
 64     int temp = all,x=0,y=0;
 65     for(int i=0;i<9;i++)
 66       for(int j=0;j<9;j++)
 67         if(mp[i][j]=='.')
 68         {
 69             int t = get(i,j);
 70             if(sum[t]<temp)
 71             {
 72                 temp = sum[t];
 73                 x = i,y = j;
 74             }
 75         }
 76     int s = get(x,y);
 77     for(int i=s;i;i-=lowbit(i))
 78     {
 79         int j = pre[lowbit(i)];
 80         draw(x,y,j,1);
 81         if(dfs(now-1)) return 1;
 82         draw(x,y,j,0);
 83     }
 84     return 0; 
 85 }
 86 int main()
 87 {
 88     all = (1<<9)-1;
 89     for(int i=0;i<9;i++) pre[1<<i] = i;
 90     for(int i=1;i<all;i++)
 91       for(int j=0;j<9;j++)
 92          sum[i]+=i>>j&1;
 93     while(scanf("%s",s+1)!=EOF&&s[1]!='e')
 94     {
 95         tot = 0;
 96         change();
 97         dfs(tot);
 98         int idx = 1;
 99         for(int i=0;i<9;i++)
100           for(int j=0;j<9;j++)
101               s[idx++] = mp[i][j];
102         s[idx] = '\0';
103         printf("%s\n",s+1);
104     }
105     return 0;
106 }

 

posted @ 2021-03-09 08:09  acmloser  阅读(34)  评论(0编辑  收藏  举报