dfs解矩阵问题+记忆化搜索

地宫取宝

https://www.acwing.com/problem/content/1214/

1.纯dfs,超时得一半分

#include<bits/stdc++.h>
using namespace std;
const int MOD=1e9+7;

int q[50][50];
int n,m,k;
long long int ans;
//param: x,y express position
//max express maxvalue in mybag
//u express numbers i took
void dfs(int x,int y,int max,int u){
    //base case;
    if(x==n || y==m)    return ; //fail
    int cur=q[x][y];
    if(x==n-1 && y==m-1){
        if(u==k)    ans++;     //success and take q[n-1][m-1]
        if(u==k-1 && cur>max) ans++; //success and dont take q[n-1][m-1]
    }
    //take
    if(cur>max){
        dfs(x+1,y,cur,u+1);
        dfs(x,y+1,cur,u+1);
    }
    //no take
    dfs(x+1,y,max,u);
    dfs(x,y+1,max,u);
}
int main(){
    cin>>n>>m>>k;
    for(int i=0;i<n;i++){
        for(int j=0;j<m;j++){
            scanf("%d",&q[i][j]);
        }
    }
    dfs(0,0,-1,0);
    cout<<ans%MOD;
    return 0;
}

2.记忆化搜索

我们发现每个格子都会存在重复参数的遍历情况,于是我们采用记忆化搜索来降低时间复杂度。

只需开一个多维数组cache(维度由dfs的参数数量决定),里面保存以此相同参数的dfs的结果(通常是题目所求),只需在原dfs的代码上修改开头结尾,以及返回值根据题设进行修改;

修改开头:通常增设一个base case,先查询此参数的dfs是否存在cache中,存在则直接return cache

修改结尾:每次dfs结尾必须给cache赋值以表示存入此状态,通常是题设所求的返回值参数

#include<bits/stdc++.h>
using namespace std;
const int MOD=1e9+7;

int q[51][51];
int n,m,k;
long long cache[51][51][14][13];
//param: x,y express position
//max express maxvalue in mybag
//u express numbers i took
long long dfs(int x,int y,int max,int u){
    if(cache[x][y][max+1][u]!=-1) 
        return cache[x][y][max+1][u]; //memory search
    long long  ans=0;
    //base case;
    if(x==n || y==m ||u>k )    return 0; //fail
    int cur=q[x][y];
    if(x==n-1 && y==m-1){
        if(u==k)    ans++;     //success and take q[n-1][m-1]
        if(u==k-1 && cur>max) ans++; //success and dont take q[n-1][m-1]
        ans%=MOD;
        return ans;
    }
    //take
    if(cur>max){
       ans+= dfs(x+1,y,cur,u+1);
       ans+= dfs(x,y+1,cur,u+1);
    }
    //no take
    ans+= dfs(x+1,y,max,u);
    ans+= dfs(x,y+1,max,u);
    cache[x][y][max+1][u]=ans%MOD;
    return cache[x][y][max+1][u];
}


int main(){
    cin>>n>>m>>k;
    for(int i=0;i<n;i++){
        for(int j=0;j<m;j++){
            scanf("%d",&q[i][j]);
        }
    }
    memset(cache,-1,sizeof(cache));
    printf("%lld",dfs(0,0,-1,0));
    

    return 0;
}
posted @ 2022-02-07 22:02  秋月桐  阅读(36)  评论(0编辑  收藏  举报