状压DP

1.郊区春游
\(dp_{st,i}\)表示当前状态为st,最后一个点为i的最短距离
\(dp_{st,i}=min_{st',j} (dp_{st',j}+dis_{j,i})\)

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define YES {puts("YES");return;}
#define NO {puts("NO");return ;}
using namespace std;
const int maxn=2e5+101;
const int MOD=998244353;
const ll inf=2147383647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
int n,m,r,a[maxn];
int dis[201][201],dp[maxn][16];
int main(){
    memset(dis,0x3f3f3f3f,sizeof(dis));
    memset(dp, 0x3f3f3f3f,sizeof(dp));
    n=read();m=read();r=read();
    for(int i=1;i<=r;i++)a[i]=read();
    for(int i=1;i<=m;i++){
        int x=read(),y=read(),z=read();
        dis[x][y]=dis[y][x]=z;
    }
    for(int i=1;i<=n;i++)dis[i][i]=0;
    for(int k=1;k<=n;k++){
        for(int i=1;i<=n;i++)for(int j=1;j<=n;j++){
            dis[i][j]=min(dis[i][j],dis[i][k]+dis[k][j]);
        }
    }
    for(int i=1;i<=r;i++)dp[1<<(i-1)][i]=0;
    for(int i=1;i<(1<<r);i++)for(int j=1;j<=r;j++){
        if((i&(1<<(j-1)))==0)continue;
        for(int k=1;k<=r;k++){
            if(i&(1<<(k-1)))continue;
            dp[i|(1<<(k-1))][k]=min(dp[i|(1<<(k-1))][k],dp[i][j]+dis[a[j]][a[k]]);
        }

    }
    int ans=inf;
    for(int i=1;i<=r;i++)ans=min(ans,dp[(1<<r)-1][i]);
    cout<<ans;
    return 0;
}

2.德玛西亚万岁

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define YES {puts("YES");return;}
#define NO {puts("NO");return ;}
using namespace std;
const int maxn=2e5+101;
const int MOD=100000000;
const ll inf=2147383647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
int n,m;
ll dp[21][10001];
void solve(){
    vector<vector<int> >bk(n+2);
    for(int i=1;i<=n;i++){
        int now=0;
        for(int j=m-1;j>=0;j--)now+=(read()<<j);
        now=now^((1<<m)-1);
        for(int j=0;j<(1<<m);j++){
            if((now&j) || j&(j<<1))continue;
            bk[i].pb(j);
        }
        if(bk[i].size()==0)bk[i].pb(0);
    }
    bk[0].pb(0);dp[0][0]=1;
    bk[n+1].pb(0);
    for(int i=1;i<=n+1;i++){
        for(auto j:bk[i]){
            dp[i][j]=0;
            for(auto k:bk[i-1]){
                if(k&j)continue;
                dp[i][j]+=dp[i-1][k];
                dp[i][j]%=MOD;
            }
        }
    }
    cout<<(dp[n+1][0]%MOD+MOD)%MOD<<endl;
    return ;
}
int main(){
    while(scanf("%d%d",&n,&m)!=EOF)solve();
    return 0;
}

3.多彩的树
对于一个n个点的联通块,一共有\(C_n^2\)条路径
因为k很小,颜色选择的种类共有\(2^k -1\)
枚举选的颜色状态为st,设\(f[st]\)为选择颜色状态为st的路径方案数
注意!我们这里不是统计正好颜色数就是这么多的路径数!
比如我们枚举的颜色为 1 2 3
那么只包含颜色 1 2 的路径也在我们的统计范围!!因为要找准确的一条路径是十分困难的

我们可以将属于这个颜色的节点先筛选出来
然后看他们形成几个联通块
之后通过开头的那个性质,我们直接得到了答案
最后,我们再利用容斥原理对最终答案景行求解

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define YES {puts("YES");return;}
#define NO {puts("NO");return ;}
using namespace std;
const int maxn=2e5+101;
const int MOD=1e9+7;
const ll inf=2147383647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
ll power(ll x,ll y){
    ll ans=1;
    while(y){
        if(y&1)ans=ans*x%MOD;
        y>>=1;x=x*x%MOD;
    }
    return ans;
}
int get(int x){
    int cnt=0;
    while(x){
        if(x&1)cnt++;
        x>>=1;
    }
    return cnt;
}
int n,k,a[maxn];
int tot,head[maxn],to[maxn],nx[maxn];
void add(int x,int y){
    to[++tot]=y;nx[tot]=head[x];head[x]=tot;
}
int vis[maxn];
int dfs(int x,int st){
    vis[x]=1;
    if((st&a[x])==0)return 0;
    int ans=1;
    for(int i=head[x];i;i=nx[i]){
        int v=to[i];if(vis[v])continue;
        ans+=dfs(v,st);
    }
    return ans;
}
ll f[maxn];
int main(){
    n=read();k=read();
    for(int i=1;i<=n;i++)a[i]=1<<(read()-1);
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        add(x,y);add(y,x);
    }
    for(int i=1;i<(1<<k);i++){
        memset(vis,0,sizeof(int)*(n+1));
        for(int u=1;u<=n;u++){
            if(vis[u])continue;
            ll cnt=dfs(u,i);
            f[i]=f[i]+cnt+cnt*(cnt-1)*power(2,MOD-2)%MOD;
            f[i]%=MOD;
        }
    }
    ll ans=0;
    for(int i=1;i<(1<<k);i++){
        for(int j=i&(i-1);j;j=(j-1)&i)f[i]-=f[j]; //这句话等同于下面的for循环
        /*
        for(int j=1;j<i;j++){
            if((j&i)!=j)continue;
             //比如i的二进制状态为11010
            //枚举对应j,j=00010,01000,10000,11000,10010,01010
            //也就是枚举i包含的情况,枚举i的子集
            f[i]-=f[j];
        }
        */
        ans=ans+f[i]*power(131,get(i))%MOD;
        ans%=MOD;
    }
    cout<<(ans%MOD+MOD)%MOD;
    return 0;
}

4. [NOIP2017]宝藏
题目的意思就是找一棵生成树,使得代价和最小。
考虑在任意时刻,我们关心的只有我们已经把多少点加进生成树了,以及生成树的最大树高是多少。
那么我们就设\(dp_{s,i}\)为当前生成树已经包含集合s中的点,并且树高是i。
那么状态转移为:(设ss=s^s0)
\(dp_{s,i}=min_{s0\in s} (dp_{s0,i-1}+pay)\),其中s0是s的子集且一定能通过ss连边到s0的,pay为连完边的代价
\(pay=dis[s0][s-s0]*(i-1)\),\(dis[x][y]\)表示从x集合连边到y集合的最短路
考虑这么做为什么是对的。
但这样写方程难免会有一个疑问?(设ss=s^s0)
ss连向s0的最短路一定在i层吗?答案是不一定,但是不是最优的(因为有些点深度乘多了),所以一定能通过枚举使得存在一个ss在i层,且最优的。

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define YES {puts("YES");return;}
#define NO {puts("NO");return ;}
using namespace std;
const int maxn=2e5+101;
const int MOD=20020219;
const ll inf=2147383647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}                                                                    

int n,m,dp[5001][13],dis[5001][5001],l[13][13];                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
int main(){ 
    n=read();m=read();
    memset(l,0x3f3f3f3f,sizeof(l));
    memset(dp,0x3f3f3f3f,sizeof(dp));
    memset(dis,0x3f3f3f3f,sizeof(dp));
    for(int i=1;i<=m;i++){
        int x=read(),y=read(),z=read();
        l[x][y]=l[y][x]=min(l[x][y],z);
    }
    for(int s=0;s<(1<<n);s++)for(int x=s;x;x=(x-1)&s){
        //dis[x][y]集合x中每个点连出一条边到集合y的最短距离
        int y=s^x;
        vector<int>a,b;
        vector<int>d(n+1);
        for(int i=0;i<n;i++){
            d[i+1]=l[0][0];
            if(x&(1<<i))a.pb(i+1);
            if(y&(1<<i))b.pb(i+1);
        }
        //dis[0][s]的情况得提前预处理出来,s集合个数必须是1个,因为根只有一个
        if(y==0 && a.size()==1){dis[y][x]=0;continue;}
        else if(y==0)continue;
        for(auto yy: b)for(auto xx:a){
            d[yy]=min(d[yy],l[xx][yy]);
        }
        dis[x][y]=0;
        for(auto yy:b){
            if(d[yy]==l[0][0]){dis[x][y]=l[0][0];break;}
            dis[x][y]+=d[yy];
        }
    }
    dp[0][0]=0;
    int ans=inf;
    for(int s=0;s<(1<<n);s++)for(int i=1;i<=n;i++){
        for(int s0=s;s0;s0=(s0-1)&s){
            int ss=s0^s;
            if(dis[ss][s0]==l[0][0])continue;
            dp[s][i]=min(dp[s][i],dp[ss][i-1]+dis[ss][s0]*(i-1));
        }
        if(s==(1<<n)-1)ans=min(ans,dp[s][i]);
    }
    printf("%d\n",ans);
    return 0;
}
posted @ 2022-08-04 13:48  I_N_V  阅读(23)  评论(0编辑  收藏  举报