BZOJ2004: [Hnoi2010]Bus 公交线路
Description
小Z所在的城市有N个公交车站,排列在一条长(N-1)km的直线上,从左到右依次编号为1到N,相邻公交车站间的距离均为1km。
作为公交车线路的规划者,小Z调查了市民的需求,决定按下述规则设计线路:
1.设共K辆公交车,则1到K号站作为始发站,N-K+1到N号台作为终点站。
2.每个车站必须被一辆且仅一辆公交车经过(始发站和终点站也算被经过)。
3.公交车只能从编号较小的站台驶往编号较大的站台。
4.一辆公交车经过的相邻两个站台间距离不得超过Pkm。
在最终设计线路之前,小Z想知道有多少种满足要求的方案。
由于答案可能很大,你只需求出答案对30031取模的结果。
Input
仅一行包含三个正整数N K P,分别表示公交车站数,公交车数,相邻站台的距离限制。
N<=10^9,1<P<=10,K<N,1<K<=P
Output
仅包含一个整数,表示满足要求的方案数对30031取模的结果。
Sample Input
样例一:10 3 3
样例二:5 2 3
样例三:10 2 4
样例二:5 2 3
样例三:10 2 4
Sample Output
1
3
81
3
81
HINT
【样例说明】
样例一的可行方案如下: (1,4,7,10),(2,5,8),(3,6,9)
样例二的可行方案如下: (1,3,5),(2,4) (1,3,4),(2,5) (1,4),(2,3,5)
P<=10 , K <=8
题解Here!
看到P<=10,立马想到状压DP。
然后本蒟蒻就不会了,还是太菜了。。。
注意看这样一种路线:
A B C _ _->_ B C A _->_ B _ A C
A B C _ _->A B _ _ C->_ B _ A C
虽然顺序不同,但是他们是同一种方案,都是A由1到4,C由3到5。
所以我们不妨强制要求必须得最靠前的先走。
这样一来就可以转移了。
一个P位的二进制位,恰好有k个1且最高位为1表示状态。
我刚开始不明白为什么恰好有k个,也不明白为什么最高位为1,想到那个强制要求之后就懂了。
所以合法状态最多有C94=126。
但是那个N<=109怎么解?
时间复杂度最高也只有log2N×C94。
注:矩阵乘法真是个玄学的东东。。。
附代码:
#include<iostream> #include<algorithm> #include<cstdio> #include<cstring> #define MAXN 210 #define MOD 30031 using namespace std; int n,m=0,k,p; int bit[20],val[MAXN]; struct Matrix{ long long a[MAXN][MAXN]; }base,ans; inline int read(){ int date=0,w=1;char c=0; while(c<'0'||c>'9'){if(c=='-')w=-1;c=getchar();} while(c>='0'&&c<='9'){date=date*10+c-'0';c=getchar();} return date*w; } Matrix operator *(Matrix x,Matrix y){ Matrix ret; for(int i=1;i<=m;i++) for(int j=1;j<=m;j++){ ret.a[i][j]=0; for(int k=1;k<=m;k++){ ret.a[i][j]+=x.a[i][k]*y.a[k][j]%MOD;; ret.a[i][j]%=MOD; } } return ret; } Matrix mexp(int b){ Matrix s; for(int i=1;i<=m;i++)s.a[i][i]=1; while(b){ if(b&1)s=s*base; base=base*base; b>>=1; } return s; } inline int lowbit(int x){return x&(-x);} void dfs(int x,int s,int v){ if(s==k){ val[++m]=v; return; } for(int i=x-1;i;i--)dfs(i,s+1,v+bit[i-1]); } void work(){ ans.a[1][1]=1; base=mexp(n-k); ans=ans*base; printf("%lld\n",ans.a[1][1]); } void init(){ n=read();k=read();p=read(); bit[0]=1; for(int i=1;i<=19;i++)bit[i]=bit[i-1]<<1; dfs(p,1,bit[p-1]); for(int i=1;i<=m;i++) for(int j=1;j<=m;j++){ int x=(val[i]<<1)^bit[p]^val[j]; if(x==lowbit(x))base.a[i][j]=1; } } int main(){ init(); work(); return 0; }