BZOJ4000 [TJOI2015]棋盘 【状压dp + 矩阵优化】
题目链接
题解
注意题目中的编号均从\(0\)开始= =
\(m\)特别小,考虑状压
设\(f[i][s]\)为第\(i\)行为\(s\)的方案数
每个棋子能攻击的只有本行,上一行,下一行,
我们能迅速找出哪些状态是合法的,以及每个状态所对应的上一行攻击位置的并和下一行攻击位置的并
如果两个状态上下相互攻击不到,就是合法的转移
我们弄一个\(2^m * 2^m\)的转移矩阵,就可以矩阵优化了
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#define uint unsigned int
#define LL long long int
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define BUG(s,n) for (int i = 1; i <= (n); i++) cout<<s[i]<<' '; puts("");
using namespace std;
const int maxn = 65,maxm = 100005,INF = 1000000000;
inline int read(){
int out = 0,flag = 1; char c = getchar();
while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
return out * flag;
}
int s1,s2,s3,n,m,p,k;
struct Matrix{
uint s[maxn][maxn]; int n,m;
Matrix(){memset(s,0,sizeof(s)); n = m = 0;}
}A,F,Fn;
inline Matrix operator *(const Matrix& a,const Matrix& b){
Matrix c;
if (a.m != b.n) return c;
c.n = a.n; c.m = b.m;
for (int i = 0; i < c.n; i++)
for (int j = 0; j < c.m; j++)
for (int k = 0; k < a.m; k++)
c.s[i][j] += a.s[i][k] * b.s[k][j];
return c;
}
inline Matrix qpow(Matrix a,int b){
Matrix ans; ans.n = ans.m = a.n;
for (int i = 0; i < ans.n; i++) ans.s[i][i] = 1;
for (; b; b >>= 1,a = a * a)
if (b & 1) ans = ans * a;
return ans;
}
int f[maxn];
bool check(int s){
for (int i = 0; i < m; i++){
if (s & (1 << i)){
if (i + 1 >= p - k){
if (s & (s2 << (i + 1 - (p - k)))) return false;
}
else if (s & (s2 >> ((p - k) - i - 1))) return false;
}
}
return true;
}
int getu(int s){
int re = 0;
for (int i = 0; i < m; i++){
if (s & (1 << i)){
if (i + 1 >= p - k) re |= s1 << (i + 1 - (p - k));
else re |= s1 >> ((p - k) - i - 1);
}
}
return re;
}
int getd(int s){
int re = 0;
for (int i = 0; i < m; i++){
if (s & (1 << i)){
if (i + 1 >= p - k) re |= s3 << (i + 1 - (p - k));
else re |= s3 >> ((p - k) - i - 1);
}
}
return re;
}
void print(int x){
for (int i = 4; i >= 0; i--)
printf("%d",(x & (1 << i)) != 0);
}
int main(){
n = read(); m = read();
p = read(); k = read();
REP(i,p) s1 = (s1 << 1) + read();
REP(i,p){
if (i == k + 1) s2 <<= 1,read();
else s2 = (s2 << 1) + read();
}
REP(i,p) s3 = (s3 << 1) + read();
int N = 1 << m;
F.n = N; F.m = 1;
for (int s = 0; s < N; s++)
if (check(s)) F.s[s][0] = 1;
A.n = A.m = N;
for (int s = 0; s < N; s++){
if (!check(s)) continue;
for (int e = 0; e < N; e++){
if (!check(e)) continue;
int u = getu(s),d = getd(e);
if (!(s & d) && !(e & u))
A.s[s][e] = 1;
}
}
Fn = qpow(A,n - 1) * F;
uint ans = 0;
for (int i = 0; i < N; i++) ans += Fn.s[i][0];
cout << ans << endl;
return 0;
}