AT2371 [AGC013E] Placing Squares
https://www.luogu.com.cn/problem/AT2371
首先把转移方程列出来
f
[
i
]
=
∑
j
=
0
i
−
1
f
[
j
]
(
i
−
j
)
2
f[i]=\sum_{j=0}^{i-1}f[j](i-j)^2
f[i]=j=0∑i−1f[j](i−j)2
对于
i
i
i为关键点的
f
[
i
]
=
0
f[i]=0
f[i]=0
然后再考虑
i
+
1
i+1
i+1后的转移
f [ i + 1 ] = ∑ j = 0 i f [ j ] ( i + 1 − j ) 2 = ∑ ( i − j ) 2 f [ j ] + 2 ∑ ( i − j ) f [ j ] + ∑ f [ j ] f[i+1]=\sum_{j=0}^{i}f[j](i+1-j)^2 \\= \sum(i-j)^2f[j]+2\sum(i-j)f[j]+\sum f[j] f[i+1]=∑j=0if[j](i+1−j)2=∑(i−j)2f[j]+2∑(i−j)f[j]+∑f[j]
设 a [ i ] = ∑ ( i − j ) 2 f [ j ] b [ i ] = ∑ ( i − j ) f [ j ] c [ i ] = ∑ f [ j ] 设a[i]=\sum(i-j)^2f[j]\\b[i]=\sum(i-j)f[j]\\c[i]=\sum f[j] 设a[i]=∑(i−j)2f[j]b[i]=∑(i−j)f[j]c[i]=∑f[j]
容易发现 f [ i + 1 ] = a [ i ] f[i+1]=a[i] f[i+1]=a[i]
然后再看 a , b , c a,b,c a,b,c的转移
c [ i + 1 ] = c [ i ] + f [ i + 1 ] = c [ i ] + f [ i + 1 ] b [ i + 1 ] = b [ i ] + c [ i ] + f [ i + 1 ] c [ i + 1 ] = a [ i ] + 2 b [ i ] + c [ i ] + f [ i + 1 ] c[i+1]=c[i]+f[i+1]=c[i]+f[i+1]\\b[i+1]=b[i]+c[i]+f[i+1]\\c[i+1]=a[i]+2b[i]+c[i]+f[i+1] c[i+1]=c[i]+f[i+1]=c[i]+f[i+1]b[i+1]=b[i]+c[i]+f[i+1]c[i+1]=a[i]+2b[i]+c[i]+f[i+1]
然后 f [ i + 1 ] = a [ i ] f[i+1]=a[i] f[i+1]=a[i],替换进去就可以得到转移矩阵了
对于 i + 1 i+1 i+1为关键点的,把每项后面的 f [ i + 1 ] f[i+1] f[i+1]扔掉就行了
写成矩阵大概是
A
=
[
1
0
0
2
1
0
1
1
1
]
A=\begin{bmatrix} 1 & 0 & 0 \\ 2 & 1 & 0 \\ 1 & 1 & 1 \end{bmatrix}
A=⎣⎡121011001⎦⎤
B
=
[
2
1
1
2
1
0
1
1
1
]
B=\begin{bmatrix} 2 & 1 & 1 \\ 2 & 1 & 0 \\ 1 & 1 & 1 \end{bmatrix}
B=⎣⎡221111101⎦⎤
[
a
,
b
,
c
]
∗
B
[a,b,c] * B
[a,b,c]∗B
然后对于关键点
[
a
,
b
,
c
]
∗
A
[a,b,c]*A
[a,b,c]∗A
关键点之间的用矩阵快速幂加速即可
code:
#include<bits/stdc++.h>
#define N 400050
#define mod 1000000007
using namespace std;
struct MT {
int a[3][3];
void init() {
for(int i = 0; i < 3; i ++)
for(int j = 0; j < 3; j ++) a[i][j] = (i == j);
}
MT operator * (const MT& o) const {
MT c;
for(int i = 0; i < 3; i ++)
for(int j = 0; j < 3; j ++) {
c.a[i][j] = 0;
for(int k = 0; k < 3; k ++)
c.a[i][j] = (c.a[i][j] + 1ll * a[i][k] * o.a[k][j] % mod) % mod;
}
return c;
}
};
MT qpow(MT x, int y) {
MT ret; ret.init();
for(; y; y >>= 1, x = x * x) if(y & 1) ret = ret * x;
return ret;
}
int n, m, x[N];
int main() {
MT ans, a, b;
ans.init();
a.a[0][0] = 1, a.a[0][1] = 0, a.a[0][2] = 0;
a.a[1][0] = 2, a.a[1][1] = 1, a.a[1][2] = 0;
a.a[2][0] = 1, a.a[2][1] = 1, a.a[2][2] = 1;
b.a[0][0] = 2, b.a[0][1] = 1, b.a[0][2] = 1;
b.a[1][0] = 2, b.a[1][1] = 1, b.a[1][2] = 0;
b.a[2][0] = 1, b.a[2][1] = 1, b.a[2][2] = 1;
x[0] = -1;
scanf("%d%d", &n, &m);
for(int i = 1; i <= m; i ++) {
scanf("%d", &x[i]);
ans = ans * qpow(b, x[i] - x[i - 1] - 1) * a;
}
ans = ans * qpow(b, n - x[m] - 1);
printf("%d", ans.a[2][0]);
return 0;
}