题解 棋盘
发现可以矩阵优化转移
一次能跳两行,可以将矩阵开大一倍存一下上一行的信息
- 对于一类形如「对一些矩阵维护队列,要每次查询整个队列中的矩阵的乘积」的问题:
那么我们维护两个栈,每个维护一段区间 \([l, m], [m+1, r]\)
第一个栈的第 \(i\) 个元素,维护从第 \(i\) 行到第 \(m\) 行的转移
第二个栈的元素维护从 \(m\) 到 \(j\) 的转移
push:直接在第二个栈加入一个元素
pop:
1.如果第一个栈有元素,直接弹掉
2.如果第一个栈没有元素,则弹空第二个栈的元素,重构之后加入第一个栈
Que:合并两个栈栈顶的转移矩阵的情况即可
复杂度分析:每个元素第一次插入会在第二个栈,最多会被重构一次
因此复杂度为均摊 \(O((q+m)n^3)\)
然后发现因为矩阵开大了带了个8的常数,就过不去了
所以需要题解做法:
先咕了
补上了,但康了std无数遍
现在让每个矩阵只存储一行的信息
于是转移需要用两行来转移
考虑从 \(l\) 到 \(r\) 的所有转移路径,要么经过 \(mid\),要么恰好跳过 \(mid\)
于是维护两对栈,分别存 \([l, mid]\&[mid, r]\) 的走法数和 \([l, mid-1]\&[mid+1, r]\) 的
具体的,令 \(tran[i][j][k]\) 为从mid的第j列走到i的第k列的方案数
栈的部分和上面一样
特别注意一个细节,tran1存的区间较小,所以在处理转移的时候应要求 \(r-2>mid\) 而不是 \(r-2\geqslant mid\)
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 50010
#define ll long long
//#define int long long
int n, q;
const int mod=998244353;
inline void md(int& a, int b) {a+=b; a=a>=mod?a-mod:a;}
#if 0
namespace task1{
char s[N][21], t[N];
int op[N], ql[N], qr[N], qx[N], qy[N];
struct matrix{
int n, m;
int a[11][11];
matrix(){memset(a, 0, sizeof(a));}
matrix(int x, int y){n=x; m=y; memset(a, 0, sizeof(a));}
void resize(int x, int y){n=x; m=y; memset(a, 0, sizeof(a));}
void put() {for (int i=1; i<=n; ++i) {for (int j=1; j<=m; ++j) cout<<a[i][j]<<' '; cout<<endl;}}
inline int* operator [] (int t) {return a[t];}
inline matrix operator * (matrix &b) {
matrix ans(n, b.m);
for (int i=1; i<=n; ++i)
for (int k=1; k<=m; ++k)
for (int j=1; j<=b.m; ++j)
ans[i][j]=(ans[i][j]+1ll*a[i][k]*b[k][j])%mod;
return ans;
}
}mat[N], dat[N<<2], v;
int tl[N<<2], tr[N<<2];
#define tl(p) tl[p]
#define tr(p) tr[p]
#define dat(p) dat[p]
#define pushup(p) dat(p)=dat(p<<1)*dat(p<<1|1)
void build(int p, int l, int r) {
tl(p)=l; tr(p)=r;
if (l==r) {dat(p)=mat[l]; return ;}
int mid=(l+r)>>1;
build(p<<1, l, mid);
build(p<<1|1, mid+1, r);
pushup(p);
}
matrix query(int p, int l, int r) {
if (l<=tl(p)&&r>=tr(p)) {return dat(p);}
int mid=(tl(p)+tr(p))>>1;
if (l<=mid&&r>mid) return query(p<<1, l, r)*query(p<<1|1, l, r);
else if (l<=mid) return query(p<<1, l, r);
else return query(p<<1|1, l, r);
}
void solve() {
// cout<<double(sizeof(dat)+sizeof(mat))/1000/1000<<endl; return ;
int l=1, r=0;
for (int i=1; i<=q; ++i) {
scanf("%s", t);
if (*t=='A') {
++r;
scanf("%s", s[r]+1);
mat[r].resize(n*2, n*2);
for (int j=1; j<=n; ++j) {
if (s[r][j]=='.') {
if (j-2>=1) mat[r][n+j-2][n+j]+=1;
if (j-1>=1) mat[r][j-1][n+j]+=1;
if (j+1<=n) mat[r][j+1][n+j]+=1;
if (j+2<=n) mat[r][n+j+2][n+j]+=1;
}
mat[r][n+j][j]+=1;
}
}
else if (*t=='D') ++l;
else {
op[i]=1; ql[i]=l; qr[i]=r;
scanf("%d%d", &qx[i], &qy[i]);
}
}
build(1, 1, r);
// v.resize(1, n*2);
// v[1][9]=1;
// v=v*mat[2]; v.put(); cout<<endl;
// v=v*mat[3]; v.put(); cout<<endl;
// mat[2].put(); cout<<endl;
// return ;
for (int i=1; i<=q; ++i) if (op[i]) {
if (ql[i]>qr[i]) puts("0");
else if (s[ql[i]][qx[i]]=='#' || s[qr[i]][qy[i]]=='#') puts("0");
else if (ql[i]==qr[i]) puts(qx[i]==qy[i]?"1":"0");
else {
// cout<<"q: "<<ql[i]<<' '<<qr[i]<<' '<<qx[i]<<' '<<qy[i]<<endl;
v.resize(1, n*2);
v[1][n+qx[i]]=1;
// v.put(); cout<<endl;
// query(1, ql[i]+1, qr[i]).put(); cout<<endl;
// mat[r].put(); cout<<endl;
// (v*query(1, ql[i]+1, qr[i])).put(); cout<<endl;
printf("%d\n", (v*query(1, ql[i]+1, qr[i]))[1][n+qy[i]]);
}
}
}
}
namespace task2{
char s[N][21], t[N];
int top1, top2;
struct matrix{
int n, m;
int a[41][41];
matrix(){/*memset(a, 0, sizeof(a));*/}
matrix(int x, int y){n=x; m=y; memset(a, 0, sizeof(a));}
void resize(int x, int y){n=x; m=y; memset(a, 0, sizeof(a));}
void put() {for (int i=1; i<=n; ++i) {for (int j=1; j<=m; ++j) cout<<a[i][j]<<' '; cout<<endl;}}
inline int* operator [] (int t) {return a[t];}
inline matrix operator * (matrix &b) {
matrix ans(n, b.m);
for (int k=1; k<=m; ++k)
for (int j=1; j<=b.m; ++j) if (b[k][j])
for (int i=1; i<=n; ++i) if (a[i][k])
ans[i][j]=(ans[i][j]+1ll*a[i][k]*b[k][j])%mod;
return ans;
}
}s1[N], s2[N], s3[N], v;
void move() {
if (top1==0) {
s1[++top1]=s2[top2--];
while (top2) {
s1[top1+1]=s2[top2--]*s1[top1];
++top1;
}
}
}
void solve() {
// cout<<double(sizeof(s1)*3)/1000/1000<<endl; return ;
int l=1, r=0, x, y;
for (int i=1; i<=q; ++i) {
scanf("%s", t);
if (*t=='A') {
++r;
scanf("%s", s[r]+1);
v.resize(n*2, n*2);
for (int j=1; j<=n; ++j) {
if (s[r][j]=='.') {
if (j-2>=1) v[n+j-2][n+j]+=1;
if (j-1>=1) v[j-1][n+j]+=1;
if (j+1<=n) v[j+1][n+j]+=1;
if (j+2<=n) v[n+j+2][n+j]+=1;
}
v[n+j][j]+=1;
}
s2[++top2]=v;
if (top2>1) s3[top2]=s3[top2-1]*v;
else s3[top2]=v;
}
else if (*t=='D') {
++l;
move();
--top1;
}
else {
scanf("%d%d", &x, &y);
if (l>r) puts("0");
else if (s[l][x]=='#' || s[r][y]=='#') puts("0");
else if (l==r) puts(x==y?"1":"0");
else {
// cout<<"top: "<<top1<<' '<<top2<<endl;
v.resize(1, n*2);
v[1][n+x]=1;
move();
if (top1>1) v=v*s1[top1-1];
if (top2) v=v*s3[top2];
printf("%d\n", v[1][n+y]);
}
}
}
exit(0);
}
}
#endif
namespace task{
char s[N][21], t[N];
int tran[N][21][21], tran1[N][21][21];
int l=1, r=0, mid=0;
void gettran(int now, int lst, int llst) {
memset(tran[now], 0, sizeof(tran[now]));
for (int i=1; i<=n; ++i) if (s[now][i]=='.') {
for (int j=1; j<=n; ++j) {
if (i>1) md(tran[now][i][j], tran[llst][i-1][j]);
if (i<n) md(tran[now][i][j], tran[llst][i+1][j]);
if (i>2) md(tran[now][i][j], tran[lst][i-2][j]);
if (i<n-1) md(tran[now][i][j], tran[lst][i+2][j]);
}
}
}
void gettran1(int now, int lst, int llst) {
memset(tran1[now], 0, sizeof(tran1[now]));
for (int i=1; i<=n; ++i) if (s[now][i]=='.') {
for (int j=1; j<=n; ++j) {
if (i>1) md(tran1[now][i][j], tran1[llst][i-1][j]);
if (i<n) md(tran1[now][i][j], tran1[llst][i+1][j]);
if (i>2) md(tran1[now][i][j], tran1[lst][i-2][j]);
if (i<n-1) md(tran1[now][i][j], tran1[lst][i+2][j]);
}
}
}
void rebuild() {
mid=r;
memset(tran[r], 0, sizeof(tran[r]));
for (int i=1; i<=n; ++i) if (s[r][i]=='.') tran[r][i][i]=1;
for (int i=r-1; i>=l; --i) gettran(i, i+1, i+2<=r?i+2:0);
if (l+1<=r) {
memset(tran1[r-1], 0, sizeof(tran[r-1]));
for (int i=1; i<=n; ++i) if (s[r-1][i]=='.') tran1[r-1][i][i]=1;
for (int i=r-2; i>=l; --i) gettran1(i, i+1, i+2<r?i+2:0);
}
}
void put(int k) {
// cout<<"tran("<<k<<','<<to<<") "; for (int i=1; i<=n; ++i) cout<<tran[k][i][to]<<' '; cout<<endl;
cout<<"tran: "<<k<<endl;
for (int i=1; i<=n; ++i) {
for (int j=1; j<=n; ++j) cout<<tran[k][i][j]<<' '; cout<<endl;
}
}
void solve() {
for (int i=1,x,y; i<=q; ++i) {
scanf("%s", t);
if (*t=='A') {
scanf("%s", s[++r]+1);
if (l>mid) rebuild();
else {
gettran(r, r-1, r-2>=mid?r-2:0);
if (r==mid+1) {
memset(tran1[r], 0, sizeof(tran1[r]));
for (int j=1; j<=n; ++j) if (s[r][j]=='.') tran1[r][j][j]=1;
}
else gettran1(r, r-1, r-2>mid?r-2:0); // 这里脑残取等了 tran1是 [l, mid-1]&[mid+1, r],不能等于mid
}
}
else if (*t=='D') {
++l;
if (l>mid && l<=r) rebuild();
}
else {
scanf("%d%d", &x, &y);
if (l>r) puts("0");
else if (s[l][x]=='#' || s[r][y]=='#') puts("0");
else if (l==r) puts(x==y?"1":"0");
else {
// cout<<"lr: "<<l<<' '<<r<<endl;
ll ans=0;
for (int j=1; j<=n; ++j) ans=(ans+1ll*tran[l][x][j]*tran[r][y][j])%mod;
if (l<mid && r>mid) {
for (int j=1; j<=n; ++j) {
if (j>1) ans=(ans+1ll*tran1[l][x][j-1]*tran1[r][y][j])%mod;
if (j<n) ans=(ans+1ll*tran1[l][x][j+1]*tran1[r][y][j])%mod;
}
}
printf("%lld\n", ans);
}
}
}
exit(0);
}
}
signed main()
{
freopen("chess.in", "r", stdin);
freopen("chess.out", "w", stdout);
scanf("%d%d", &n, &q);
task::solve();
return 0;
}