bzoj2004 矩阵快速幂优化状压dp

https://www.lydsy.com/JudgeOnline/problem.php?id=2004

以前只会状压dp和矩阵快速幂dp,没想到一道题还能组合起来一起用,算法竞赛真是奥妙重重

小Z所在的城市有N个公交车站,排列在一条长(N-1)km的直线上,从左到右依次编号为1到N,相邻公交车站间的距
离均为1km。 作为公交车线路的规划者,小Z调查了市民的需求,决定按下述规则设计线路:
1.设共K辆公交车,则1到K号站作为始发站,N-K+1到N号台作为终点站。
2.每个车站必须被一辆且仅一辆公交车经过(始发站和
终点站也算被经过)。 
3.公交车只能从编号较小的站台驶往编号较大的站台。 
4.一辆公交车经过的相邻两个
站台间距离不得超过Pkm。 在最终设计线路之前,小Z想知道有多少种满足要求的方案。由于答案可能很大,你只
需求出答案对30031取模的结果。
题意

 

其实这不是一个难想的题目,P小于10的范围很容易就会想到去状态压缩,dp题的用意表达的也比较刻意,N的范围1e9又在含沙射影的告诉我这得矩乘优化。

这就得出了这题的大致算法,但是比较困难的事实上是怎么去方程转移怎么去优化。

和以前的套路一样,考虑先写一个朴素算法。看了题目很容易发现,在p公里的区间内,一定会出现K辆车停过的站牌,用dp[i][j]表示到了i这个位置,状态为j的数量个数。1表示这个位置有一辆车,0表示这个位置的车已经开到后面去了。

状态转移方程是每个状态考虑一个有车的位置开到i + 1这个位置的状态。由此我们可以推出一个朴素算法

#include <map>
#include <set>
#include <ctime>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std;
#define For(i, x, y) for(int i=x;i<=y;i++)  
#define _For(i, x, y) for(int i=x;i>=y;i--)
#define Mem(f, x) memset(f,x,sizeof(f))  
#define Sca(x) scanf("%d", &x)
#define Sca2(x,y) scanf("%d%d",&x,&y)
#define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z)
#define Scl(x) scanf("%lld",&x);  
#define Pri(x) printf("%d\n", x)
#define Prl(x) printf("%lld\n",x);  
#define CLR(u) for(int i=0;i<=N;i++)u[i].clear();
#define LL long long
#define ULL unsigned long long  
#define mp make_pair
#define PII pair<int,int>
#define PIL pair<int,long long>
#define PLL pair<long long,long long>
#define pb push_back
#define fi first
#define se second 
typedef vector<int> VI;
const double eps = 1e-9;
const int maxn = 110;
const int INF = 0x3f3f3f3f;
const int mod = 30031; 
int N,M,K,P;
int dp[2][1 << 10]; //在这个点之前p个位置的状态
int pre[1 << 10];
bool limit[1 << 10];
int usable[1 << 10],cnt;
void init(){
    For(i,0,(1 << P) - 1){
        int num = 0;
        For(j,0,P - 1){
            if(i & (1 << j)){
                if(j == P - 1) limit[i] = 1;
                num++;
            }
        }
        if(num == K && (i & 1)) usable[++cnt] = i;
    }
}
int main()
{
    Sca3(N,K,P);
    init();
    int s = 0;
    For(i,0,K - 1) s |= (1 << i);
    dp[K & 1][s] = 1;
    For(i,K + 1,N){
        Mem(dp[i & 1],0);
        For(j,1,cnt){
            int t = usable[j];        
            if(limit[usable[j]]){
                t ^= (1 << (P - 1));
                t <<= 1; t++;
                dp[i & 1][t] = (dp[i & 1][t] + dp[i + 1 & 1][usable[j]]) % mod;
            }else{
                t <<= 1; t++;
                For(k,1,P){
                    if(t & (1 << k)) dp[i & 1][t ^ (1 << k)] = (dp[i & 1][t ^ (1 << k)] + dp[i + 1 & 1][usable[j]]) % mod;
                }
            }
        }
    }
    Pri(dp[N & 1][s]);
    #ifdef VSCode
    system("pause");
    #endif
    return 0;
}
憨厚老实朴素算法

这是一个时间复杂度和空间复杂度双双爆炸的算法,空间我们可以用滚动数组来优化掉,但是1e9的N是无论如何也不可能去线性递推出来的。

当写完这样一个朴素算法的时候,快速矩阵幂就很容易写出来去优化了,

矩阵内数字的意义是这个状态下的数量个数,每次矩乘相当于到了下一个站牌状态数量的更新。

#include <map>
#include <set>
#include <ctime>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std;
#define For(i, x, y) for(int i=x;i<=y;i++)  
#define _For(i, x, y) for(int i=x;i>=y;i--)
#define Mem(f, x) memset(f,x,sizeof(f))  
#define Sca(x) scanf("%d", &x)
#define Sca2(x,y) scanf("%d%d",&x,&y)
#define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z)
#define Scl(x) scanf("%lld",&x);  
#define Pri(x) printf("%d\n", x)
#define Prl(x) printf("%lld\n",x);  
#define CLR(u) for(int i=0;i<=N;i++)u[i].clear();
#define LL long long
#define ULL unsigned long long  
#define mp make_pair
#define PII pair<int,int>
#define PIL pair<int,long long>
#define PLL pair<long long,long long>
#define pb push_back
#define fi first
#define se second 
typedef vector<int> VI;
const double eps = 1e-9;
const int maxn = 110;
const int INF = 0x3f3f3f3f;
const int mod = 30031; 
int N,M,K,P;
bool limit[1 << 11];
int id[1 << 11];
int usable[310],cnt;
struct Mat{
    LL a[210][210];
    void init(){
        Mem(a,0);
    }
};
Mat operator *(Mat a,Mat b){
    Mat ans; ans.init();
    For(i,0,cnt){
        For(j,0,cnt){
            For(k,0,cnt){
                ans.a[i][j] = (ans.a[i][j] + a.a[i][k] * b.a[k][j]) % mod;
            }
        }
    }
    return ans;
}
void init(){
    cnt = -1;
    For(i,0,(1 << P) - 1){
        int num = 0;
        For(j,0,P - 1) num += ((i & (1 << j)) != 0);
        limit[i] = (i & (1 << (P - 1)));
        if(num == K && (i & 1)){
            usable[++cnt] = i;
            id[i] = cnt;
        } 
    }
}
void solve(){
    int s = (1 << K) - 1;
    Mat base,ans;
    ans.init(); base.init();
    ans.a[0][id[s]] = 1;
    For(i,0,cnt){
        int t = usable[i];
        if(limit[t]){
            t ^= (1 << (P - 1)); t <<= 1; t++;
            base.a[i][id[t]] = 1;
        }else{
            t <<= 1; t++;
            For(k,1,P){
                if(t & (1 << k)){
                    base.a[i][id[t ^ (1 << k)]] = 1;
                }
            }
        }
    }
    N -= K;
    while(N){
        if(N & 1) ans = ans * base;
        base = base * base;
        N >>= 1;
    }
    Prl(ans.a[0][id[s]]);
}
int main()
{
    Sca3(N,K,P);
    init();
    solve();
    #ifdef VSCode
    system("pause");
    #endif
    return 0;
}

 

posted @ 2018-10-01 20:40  Hugh_Locke  阅读(357)  评论(0编辑  收藏  举报