[bzoj4348]ParkII【插头dp】

【题目链接】
  https://www.lydsy.com/JudgeOnline/problem.php?id=2310
【题解】
  插头dp中一道较为繁琐(?)的题。
  多开一维状态，记录已选的度为 1 的点（即路径端点）的个数。同时在状态的括号序列中新添一种状态 3，表示另一端连向路径端点的单独插头；接下来就是分类讨论了。细节有点多，但思维难度不大，具体实现见代码。
  时间复杂度 $O(NM \cdot 2^{2(N+2)})$（指数中的 $N$ 指较短一维的长度：状态共有 $N+2$ 个位置，每个位置占 2 bit）。
  

/* --------------
    user Vanisher
    problem bzoj-2310 
----------------*/
# include <bits/stdc++.h>
// Array-size limits.
//   N : bound on both grid dimensions (mp/ne are N x N).
//   T : number of 2-bit slots available in one plug-DP state, so the shorter
//       grid side may be at most T-2 (slots 0..m+1 are used).
#define N 110
#define T 10
// Conventional shorthands; neither is referenced elsewhere in this file.
#define ll long long
#define inf 0x3f3f3f3f
using namespace std;
// Reads the next decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. Returns 0 if end of input is reached before a
// digit is seen.
int read(){
    int tmp=0, fh=1;
    // int, not char: getchar() returns EOF as an int that a char cannot hold
    // reliably (it never compares equal to EOF when char is unsigned).
    int ch=getchar();
    // Stop at EOF; otherwise truncated input would make this loop spin forever.
    while (ch!=EOF&&(ch<'0'||ch>'9')){
        if (ch=='-') fh=-1;
        ch=getchar();
    }
    while (ch>='0'&&ch<='9'){
        tmp=tmp*10+ch-'0';
        ch=getchar();
    }
    return tmp*fh;
}
// Plug state of the cell being processed. now[] is the state being expanded and
// nex[] is the successor under construction. Slot 0 counts path endpoints placed
// so far; slots 1..m+1 hold plugs (0 none, 1 '(', 2 ')', 3 plug leading to an
// endpoint).
int now[T],nex[T];
// Grid dimensions (m is made the shorter side in main).
int n,m;
// Cell weights; ne[][] is scratch space used while transposing mp[][].
int mp[N][N],ne[N][N];
// f1 / f2 select the current / next rolling buffer; lim bounds valid encoded states.
int f1,f2,lim;
// h[state] = position of state in g[f2], or -1 when not present.
int h[1<<(T*2)];
// Printed result.
int ans;
// g[b] holds the encoded states of buffer b; f[b] holds their best path weights.
vector <int> g[2],f[2];
// Returns the slot of the bracket plug matching the one at slot `id` of now[].
// A '(' (1) is matched by scanning right for its ')' (2); a ')' is matched by
// scanning left for its '('. Nested pairs are skipped via a depth counter.
// Empty slots and independent plugs (3) are ignored.
// Returns -1 when now[id] is not a bracket plug. The bracket sequence is
// assumed balanced; otherwise the scan runs past the used slots.
int findopp(int id){
    int step, closer, opener;
    if (now[id]==1){ step=1; closer=2; opener=1; }
    else if (now[id]==2){ step=-1; closer=1; opener=2; }
    else return -1;
    int depth=0;
    for (int pos=id+step; ; pos+=step){
        if (now[pos]==closer){
            if (depth==0) return pos;
            --depth;
        }
        else if (now[pos]==opener) ++depth;
    }
}
// Packs nex[0..m+1] into one integer using 2 bits per slot:
// slot i occupies bits 2i and 2i+1.
int getnum(){
    int code=0;
    for (int slot=m+1; slot>=0; --slot)
        code=(code<<2)+nex[slot];
    return code;
}
// Inserts the successor state held in nex[] into the next rolling buffer
// (g[f2] / f[f2]) with path weight `num`. If the state is already present, its
// stored weight is raised to max(old, num). States with more than two path
// endpoints (nex[0] > 2) are discarded.
void join(int num){
    if (nex[0]>2) return;   // a simple path has at most two endpoints
    int tmp=getnum();
    if (h[tmp]==-1){
        // first time this state appears in the next buffer
        h[tmp]=g[f2].size();
        g[f2].push_back(tmp);
        f[f2].push_back(num);
    }
    else f[f2][h[tmp]]=max(f[f2][h[tmp]],num);
}
// Entry point: reads an n x m grid of cell weights and prints the best total
// weight of a path, computed with a row-major plug (broken-profile) DP.
//
// State encoding (2 bits per slot, packed by getnum()):
//   slot 0       : number of path endpoints (degree-1 cells) placed so far
//   slots 1..m+1 : frontier plugs. While cell (i,j) is processed, slot j is the
//                  plug coming from the left and slot j+1 the plug coming from
//                  above. After processing, slot j is the down plug and slot
//                  j+1 the right plug.
//   plug values  : 0 none, 1 '(', 2 ')', 3 plug whose other end is an endpoint.
// The answer is the weight of state 2: no open plugs and exactly two endpoints.
int main(){
    n=read(), m=read();
    for (int i=1; i<=n; i++)
        for (int j=1; j<=m; j++)
            mp[i][j]=read();
    // Transpose so that m (the row width, i.e. the number of plug slots) is the
    // shorter side.
    if (n<m){
        for (int i=1; i<=n; i++)
            for (int j=1; j<=m; j++)
                ne[j][i]=mp[i][j];
        swap(n,m);
        for (int i=1; i<=n; i++)
            for (int j=1; j<=m; j++)
                mp[i][j]=ne[i][j];
    }
    // lim = 4^(m+2): encoded states using only slots 0..m+1 are below it.
    f1=0, f2=1, lim=1<<((m+2)*2);
    // Start from the empty state with weight 0.
    g[f1].push_back(0);
    f[f1].push_back(0);
    memset(h,-1,sizeof(h));   // every state absent from the next buffer
    for (int i=1; i<=n; i++){
        for (int j=1; j<=m; j++){
            for (unsigned k=0; k<g[f1].size(); k++){
                int tmp=g[f1][k], num=f[f1][k], ths=mp[i][j];
                // Drops states whose plug was shifted past slot m+1 by the
                // end-of-row shift (a plug leaving the right border).
                if (tmp>=lim) continue;
                for (int t=0; t<=m+1; t++) now[t]=nex[t]=(tmp>>(t*2))&3;
                // Left plug leads to an endpoint.
                if (now[j]==3){
                    if (now[j+1]==0){
                        // Extend it down, extend it right, or end the path here
                        // (this cell becomes an endpoint).
                        nex[0]=now[0]+0, nex[j]=3, nex[j+1]=0; join(num+ths);
                        nex[0]=now[0]+0, nex[j]=0, nex[j+1]=3; join(num+ths);
                        nex[0]=now[0]+1, nex[j]=0, nex[j+1]=0; join(num+ths); 
                    }
                    else if (now[j+1]==3){
                        // Two endpoint-plugs meet: the path is closed off.
                        nex[0]=now[0]+0, nex[j]=0, nex[j+1]=0; join(num+ths);
                    }
                    else {
                        // Joined to a bracket plug: the bracket's partner now
                        // leads to an endpoint, so it becomes a 3-plug.
                        int k=findopp(j+1);
                        if (k!=-1){nex[j]=0; nex[j+1]=0; nex[k]=3; join(num+ths);}
                    }
                    continue;
                }
                // Upper plug leads to an endpoint (mirror of the case above).
                if (now[j+1]==3){
                    if (now[j]==0){
                        nex[0]=now[0]+0, nex[j]=3, nex[j+1]=0; join(num+ths);
                        nex[0]=now[0]+0, nex[j]=0, nex[j+1]=3; join(num+ths);
                        nex[0]=now[0]+1, nex[j]=0, nex[j+1]=0; join(num+ths); 
                    }
                    else if (now[j]==3){
                        // NOTE(review): unreachable - now[j]==3 was already
                        // handled by the previous branch.
                        nex[0]=now[0]+0, nex[j]=0, nex[j+1]=0; join(num+ths);
                    }
                    else {
                        int k=findopp(j);
                        if (k!=-1){nex[j]=0; nex[j+1]=0; nex[k]=3; join(num+ths);}
                    }
                    continue;
                }
                // Both plugs are brackets: merge them in this cell.
                // The (1,2) pair is not handled: merging it would close a cycle,
                // so that state is dropped.
                if (now[j]!=0&&now[j+1]!=0){
                    if (now[j]==2&&now[j+1]==1){
                        // ')' then '(': two components join, outer brackets unchanged.
                        nex[j]=0, nex[j+1]=0; join(num+ths);
                    }
                    if (now[j]==1&&now[j+1]==1){
                        // The ')' matching the upper '(' becomes '('.
                        int p=findopp(j+1);  nex[j]=0, nex[j+1]=0, nex[p]=1; join(num+ths);
                    }
                    if (now[j]==2&&now[j+1]==2){
                        // The '(' matching the left ')' becomes ')'.
                        int p=findopp(j);  nex[j]=0, nex[j+1]=0, nex[p]=2; join(num+ths);
                    }
                    continue;
                }
                // No incoming plugs.
                if (now[j]==0&&now[j+1]==0){
                    // Skip the cell (weight not added).
                    nex[j]=0, nex[j+1]=0; join(num);
                    // Start a new component going down '(' and right ')'.
                    nex[j]=1, nex[j+1]=2; join(num+ths);
                    // Start the path here as an endpoint going down, or going right.
                    nex[0]=now[0]+1, nex[j]=3, nex[j+1]=0; join(num+ths);
                    nex[0]=now[0]+1, nex[j+1]=3, nex[j]=0; join(num+ths); 
                    // Single-cell path: both endpoints at this cell.
                    nex[0]=now[0]+2, nex[j]=0, nex[j+1]=0; join(num+ths);
                    continue;
                }
                // Only the left plug (a bracket) is present.
                if (now[j]!=0){
                    // Continue right, or continue down.
                    nex[j]=0, nex[j+1]=now[j]; join(num+ths);
                    nex[j]=now[j], nex[j+1]=0; join(num+ths);
                    // End here as an endpoint; the partner bracket becomes a 3-plug.
                    int p=findopp(j);
                    nex[0]=now[0]+1, nex[j]=0, nex[j+1]=0; nex[p]=3; join(num+ths);
                    continue;
                }
                // Only the upper plug (a bracket) is present.
                if (now[j+1]!=0){
                    nex[j]=0, nex[j+1]=now[j+1]; join(num+ths);
                    nex[j]=now[j+1], nex[j+1]=0; join(num+ths);
                    int p=findopp(j+1); 
                    nex[0]=now[0]+1, nex[j]=0, nex[j+1]=0; nex[p]=3; join(num+ths);
                    continue;
                }
            }
            // Reset the hash marks of the buffer just built, discard the old
            // buffer, and make the new one current.
            for (unsigned k=0; k<g[f2].size(); k++) h[g[f2][k]]=-1;
            f[f1].clear(), g[f1].clear();
            swap(f1,f2);
        }
        // End of row: shift every plug slot up by one (2 bits) while keeping
        // the endpoint count in slot 0. Slot 1 becomes the empty left-border
        // plug of the next row.
        for (unsigned k=0; k<g[f1].size(); k++){
            int tmp=g[f1][k]&3; g[f1][k]=(g[f1][k]>>2)<<2;
            g[f1][k]=(g[f1][k]<<2)+tmp;
        }
    }
    // State 2: no open plugs and two endpoints, i.e. a complete path.
    for (unsigned i=0; i<g[f1].size(); i++)
        if (g[f1][i]==2){
            ans=f[f1][i];
            break;
        }
    printf("%d\n",ans);
    return 0;
}
posted @ 2018-04-18 22:10  Vanisher  阅读(106)  评论(0)  编辑  收藏  举报