洛谷 P1123 取数游戏
题目: 洛谷 P1123 取数游戏: https://www.luogu.org/problemnew/show/P1123
题目描述
一个N×MN \times MN×M的由非负整数构成的数字矩阵,你需要在其中取出若干个数字,使得取出的任意两个数字不相邻(若一个数字在另外一个数字相邻888个格子中的一个即认为这两个数字相邻),求取出数字和最大是多少。
输入输出格式
输入格式:第1行有一个正整数TTT,表示了有TTT组数据。
对于每一组数据,第一行有两个正整数NNN和MMM,表示了数字矩阵为NNN行MMM列。
接下来NNN行,每行MMM个非负整数,描述了这个数字矩阵。
输出格式:TTT行,每行一个非负整数,输出所求得的答案。
输入输出样例
3 4 4 67 75 63 10 29 29 92 14 21 68 71 56 8 67 91 25 2 3 87 70 85 10 3 17 3 3 1 1 1 1 99 1 1 1 1
271 172 99
说明
对于第1组数据,取数方式如下:
[67] 75 63 10
29 29 [92] 14
[21] 68 71 56
8 67 [91] 25
对于20%20\%20%的数据,N,M≤3N, M≤3N,M≤3;
对于40%40\%40%的数据,N,M≤4N,M≤4N,M≤4;
对于60%60\%60%的数据,N,M≤5N, M≤5N,M≤5;
对于100%100\%100%的数据,N,M≤6,T≤20N, M≤6,T≤20N,M≤6,T≤20。
在拿到一道题时,首先是审题,明白题目大意。
然后,自然而然地我们把目光转向数据范围。
~~~~咦!我们发现数据范围很小。那么我们就暴力搜索。
这里,有个小尴尬的地方......虽然对结果没有多大影响,但会影响代码的美观与简洁。
请大家移步至以下代码第23~31与37~45行。
可以发现我写了一长串。这是我用于判断点(x,y)与其周围的点是否可以走用的。
如果你也用这样的方法,一定要注意:
不可以使用 bool类型来帮助你判断。因为如果一个点被覆盖了两次,而你在回溯时一次就将这个点改回去,会导致错误。
使用 int类型来记录这个点被覆盖了几次。不用担心出现负值,因为你将它+1后,回溯时也有且只会将它-1,不可能出现多减的情况。
1 // 2 #include <bits/stdc++.h> 3 using namespace std; 4 typedef long long ll; 5 #define ri register ll 6 7 ll t,n,m,ans,sum; 8 ll a[10][10]; 9 ll vis[10][10]; 10 11 12 void dfs(ll x,ll y) 13 { 14 if(x>n||y>m){ans=max(ans,sum);return;} 15 for(ri i=x;i<=9;i++) 16 { 17 for(ri j=1;j<=9;j++) 18 { 19 if(vis[i][j])continue; 20 21 sum+=a[i][j]; 22 23 vis[i-1][j-1]+=1; 24 vis[i-1][j]+=1; 25 vis[i-1][j+1]+=1; 26 vis[i][j-1]+=1; 27 vis[i][j]+=1; 28 vis[i][j+1]+=1; 29 vis[i+1][j-1]+=1; 30 vis[i+1][j]+=1; 31 vis[i+1][j+1]+=1; 32 33 dfs(i,j); 34 35 sum-=a[i][j]; 36 37 vis[i-1][j-1]-=1; 38 vis[i-1][j]-=1; 39 vis[i-1][j+1]-=1; 40 vis[i][j-1]-=1; 41 vis[i][j]-=1; 42 vis[i][j+1]-=1; 43 vis[i+1][j-1]-=1; 44 vis[i+1][j]-=1; 45 vis[i+1][j+1]-=1; 46 } 47 } 48 } 49 50 void work() 51 { 52 ans=0,sum=0; 53 memset(vis,0,sizeof(vis)); 54 memset(a,0,sizeof(a)); 55 cin>>n>>m; 56 for(ri i=1;i<=n;i++) 57 { 58 for(ri j=1;j<=m;j++) 59 { 60 cin>>a[i][j]; 61 } 62 } 63 dfs(1,1); 64 cout<<ans<<'\n'; 65 } 66 67 signed main() 68 { 69 ios::sync_with_stdio(0),cin.tie(0); 70 cin>>t; 71 while(t--) 72 { 73 work(); 74 } 75 return 0; 76 } 77 //
这里,你会发现超时了。
那么接下来就是优化代码的时候了。
在优化之前,我们先来看一看那两行长长的代码。当你写下它时,你会觉得,啊!它们好长好长....要粘贴复制那么多次!
那我们有什么方法简化吗? —of course。
我们把赋值的繁琐过程改到判断中:
这是重写后的 dfs()函数。
可以看到第8行的判断使我们可以只用11,16行两句话完成回溯过程。
1 void dfs(ll x,ll y) 2 { 3 if(x>n||y>m){ans=max(ans,sum);return;} 4 for(ri i=x;i<=8;i++) 5 { 6 for(ri j=1;j<=8;j++) 7 { 8 if(vis[i][j]||vis[i-1][j-1]||vis[i-1][j]||vis[i-1][j+1]||vis[i][j-1]||vis[i][j+1]||vis[i+1][j-1]||vis[i+1][j]||vis[i+1][j+1])continue; 9 10 sum+=a[i][j]; 11 vis[i][j]=1; 12 13 dfs(i,j); 14 15 sum-=a[i][j]; 16 vis[i][j]=0; 17 } 18 } 19 return; 20 }
接下来才是真正的优化了。
仔细观察我们的代码,我们可以发现在上面这段代码( dfs()函数 )中,我们用了两个循环。
根据题意,我们可以将其循环次数减少。
但请看第3行的判断代码。 因为这个判断一定要做到最大行的下两行才能保证计算全部正确。
而我们希望第一层循环只循环 n次。那么我们要重新改进我们的判断方式。
哎呀呀!!.....这怎么改呀?能不改吗?我已经想不出来了呀!
—好吧好吧,既然你这么说,那我们先仔细看一看这句代码的作用。
我们用这句代码不断更新当前的最大值。
每当计算到最后,我们在原来的值与新的值中选取一个较大的。
我们在赋值前的判断是为了在每次计算的最后将最终结果拿出来比较。
那如果我们去掉这个判断,直接在每一步后进行比较赋值,这并不影响我们最终的结果。
所以,我们可以直接将这个判断删去。
接下来,就可以无后顾之忧地把第一层循环的次数限制到 n中把第二层循环限制到 m中。
接下来。第一层循环中,我们每次从当前搜索到的行开始,因为前面已经被全部搜过了。
所以我们一开始就赋值为 x。
那么对于第二层循环呢? —这个,我们是不能改滴。
因为在搜索过程中,挡上一层从前面移动后去时,当前行的前方的数又可以使用了。
所以,第二层循环我们一定从1开始循环至 m。
以下是参考程序:
1 // 2 #include <bits/stdc++.h> 3 using namespace std; 4 typedef long long ll; 5 #define ri register ll 6 7 ll t,n,m,ans,sum; 8 ll a[9][9]; 9 bool vis[9][9]; 10 11 void dfs(ll x,ll y) 12 { 13 ans=max(ans,sum); 14 for(ri i=x;i<=n;i++) 15 { 16 for(ri j=1;j<=m;j++) 17 { 18 if(vis[i][j]||vis[i-1][j-1]||vis[i-1][j]||vis[i-1][j+1]||vis[i][j-1]||vis[i][j+1]||vis[i+1][j-1]||vis[i+1][j]||vis[i+1][j+1])continue; 19 20 sum+=a[i][j]; 21 vis[i][j]=1; 22 23 dfs(i,j); 24 25 sum-=a[i][j]; 26 vis[i][j]=0; 27 } 28 } 29 return; 30 } 31 32 void work() 33 { 34 ans=0,sum=0; 35 memset(vis,0,sizeof(vis)); 36 memset(a,0,sizeof(a)); 37 cin>>n>>m; 38 // scanf("%lld%lld",&n,&m); 39 for(ri i=1;i<=n;i++) 40 { 41 for(ri j=1;j<=m;j++) 42 { 43 cin>>a[i][j]; 44 // scanf("%lld",&a[i][j]); 45 } 46 } 47 dfs(1,1); 48 cout<<ans<<'\n'; 49 // printf("%lld\n",ans); 50 } 51 52 signed main() 53 { 54 ios::sync_with_stdio(0),cin.tie(0); 55 cin>>t; 56 // scanf("%lld",&t); 57 while(t--) 58 { 59 work(); 60 } 61 return 0; 62 } 63 //