bzoj2004 矩阵快速幂优化状压dp
https://www.lydsy.com/JudgeOnline/problem.php?id=2004
以前只会状压dp和矩阵快速幂dp,没想到一道题还能组合起来一起用,算法竞赛真是奥妙重重
小Z所在的城市有N个公交车站,排列在一条长(N-1)km的直线上,从左到右依次编号为1到N,相邻公交车站间的距 离均为1km。 作为公交车线路的规划者,小Z调查了市民的需求,决定按下述规则设计线路: 1.设共K辆公交车,则1到K号站作为始发站,N-K+1到N号台作为终点站。 2.每个车站必须被一辆且仅一辆公交车经过(始发站和 终点站也算被经过)。 3.公交车只能从编号较小的站台驶往编号较大的站台。 4.一辆公交车经过的相邻两个 站台间距离不得超过Pkm。 在最终设计线路之前,小Z想知道有多少种满足要求的方案。由于答案可能很大,你只 需求出答案对30031取模的结果。
其实这不是一个难想的题目,P小于10的范围很容易就会想到去状态压缩,dp题的用意表达的也比较刻意,N的范围1e9又在含沙射影的告诉我这得矩乘优化。
这就得出了这题的大致算法,但是比较困难的事实上是怎么去方程转移怎么去优化。
和以前的套路一样,考虑先写一个朴素算法。看了题目很容易发现,在p公里的区间内,一定会出现K辆车停过的站牌,用dp[i][j]表示到了i这个位置,状态为j的数量个数。1表示这个位置有一辆车,0表示这个位置的车已经开到后面去了。
状态转移方程是每个状态考虑一个有车的位置开到i + 1这个位置的状态。由此我们可以推出一个朴素算法
#include <map> #include <set> #include <ctime> #include <cmath> #include <queue> #include <stack> #include <vector> #include <string> #include <cstdio> #include <cstdlib> #include <cstring> #include <sstream> #include <iostream> #include <algorithm> #include <functional> using namespace std; #define For(i, x, y) for(int i=x;i<=y;i++) #define _For(i, x, y) for(int i=x;i>=y;i--) #define Mem(f, x) memset(f,x,sizeof(f)) #define Sca(x) scanf("%d", &x) #define Sca2(x,y) scanf("%d%d",&x,&y) #define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z) #define Scl(x) scanf("%lld",&x); #define Pri(x) printf("%d\n", x) #define Prl(x) printf("%lld\n",x); #define CLR(u) for(int i=0;i<=N;i++)u[i].clear(); #define LL long long #define ULL unsigned long long #define mp make_pair #define PII pair<int,int> #define PIL pair<int,long long> #define PLL pair<long long,long long> #define pb push_back #define fi first #define se second typedef vector<int> VI; const double eps = 1e-9; const int maxn = 110; const int INF = 0x3f3f3f3f; const int mod = 30031; int N,M,K,P; int dp[2][1 << 10]; //在这个点之前p个位置的状态 int pre[1 << 10]; bool limit[1 << 10]; int usable[1 << 10],cnt; void init(){ For(i,0,(1 << P) - 1){ int num = 0; For(j,0,P - 1){ if(i & (1 << j)){ if(j == P - 1) limit[i] = 1; num++; } } if(num == K && (i & 1)) usable[++cnt] = i; } } int main() { Sca3(N,K,P); init(); int s = 0; For(i,0,K - 1) s |= (1 << i); dp[K & 1][s] = 1; For(i,K + 1,N){ Mem(dp[i & 1],0); For(j,1,cnt){ int t = usable[j]; if(limit[usable[j]]){ t ^= (1 << (P - 1)); t <<= 1; t++; dp[i & 1][t] = (dp[i & 1][t] + dp[i + 1 & 1][usable[j]]) % mod; }else{ t <<= 1; t++; For(k,1,P){ if(t & (1 << k)) dp[i & 1][t ^ (1 << k)] = (dp[i & 1][t ^ (1 << k)] + dp[i + 1 & 1][usable[j]]) % mod; } } } } Pri(dp[N & 1][s]); #ifdef VSCode system("pause"); #endif return 0; }
这是一个时间复杂度和空间复杂度双双爆炸的算法,空间我们可以用滚动数组来优化掉,但是1e9的N是无论如何也不可能去线性递推出来的。
当写完这样一个朴素算法的时候,快速矩阵幂就很容易写出来去优化了,
矩阵内数字的意义是这个状态下的数量个数,每次矩乘相当于到了下一个站牌状态数量的更新。
#include <map> #include <set> #include <ctime> #include <cmath> #include <queue> #include <stack> #include <vector> #include <string> #include <cstdio> #include <cstdlib> #include <cstring> #include <sstream> #include <iostream> #include <algorithm> #include <functional> using namespace std; #define For(i, x, y) for(int i=x;i<=y;i++) #define _For(i, x, y) for(int i=x;i>=y;i--) #define Mem(f, x) memset(f,x,sizeof(f)) #define Sca(x) scanf("%d", &x) #define Sca2(x,y) scanf("%d%d",&x,&y) #define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z) #define Scl(x) scanf("%lld",&x); #define Pri(x) printf("%d\n", x) #define Prl(x) printf("%lld\n",x); #define CLR(u) for(int i=0;i<=N;i++)u[i].clear(); #define LL long long #define ULL unsigned long long #define mp make_pair #define PII pair<int,int> #define PIL pair<int,long long> #define PLL pair<long long,long long> #define pb push_back #define fi first #define se second typedef vector<int> VI; const double eps = 1e-9; const int maxn = 110; const int INF = 0x3f3f3f3f; const int mod = 30031; int N,M,K,P; bool limit[1 << 11]; int id[1 << 11]; int usable[310],cnt; struct Mat{ LL a[210][210]; void init(){ Mem(a,0); } }; Mat operator *(Mat a,Mat b){ Mat ans; ans.init(); For(i,0,cnt){ For(j,0,cnt){ For(k,0,cnt){ ans.a[i][j] = (ans.a[i][j] + a.a[i][k] * b.a[k][j]) % mod; } } } return ans; } void init(){ cnt = -1; For(i,0,(1 << P) - 1){ int num = 0; For(j,0,P - 1) num += ((i & (1 << j)) != 0); limit[i] = (i & (1 << (P - 1))); if(num == K && (i & 1)){ usable[++cnt] = i; id[i] = cnt; } } } void solve(){ int s = (1 << K) - 1; Mat base,ans; ans.init(); base.init(); ans.a[0][id[s]] = 1; For(i,0,cnt){ int t = usable[i]; if(limit[t]){ t ^= (1 << (P - 1)); t <<= 1; t++; base.a[i][id[t]] = 1; }else{ t <<= 1; t++; For(k,1,P){ if(t & (1 << k)){ base.a[i][id[t ^ (1 << k)]] = 1; } } } } N -= K; while(N){ if(N & 1) ans = ans * base; base = base * base; N >>= 1; } Prl(ans.a[0][id[s]]); } int main() { Sca3(N,K,P); init(); solve(); #ifdef VSCode system("pause"); #endif return 0; }