[AGC026D]Histogram Coloring
计算一个区间 [ l , r ] [l, r] [l,r] 的贡献,方法就是找到这个区间内 h h h 最小的几个点,然后从这几个点裂开计算贡献。
例如上图,我们就把整个问题的求解拆分成只考虑绿色部分的求解再利用几个绿色部分的信息整合到整体的信息。
状态:
f ( p , 0 ) f (p, 0) f(p,0) 表示只考虑区间 p p p,且 p p p 最后一行为红蓝交错的满足要求的方案数
f ( p , 1 ) f (p, 1) f(p,1) 表示只考虑区间 p p p 的满足要求的方案数
g ( p ) g (p) g(p) 表示只考虑区间 p p p,且 p p p 最后一行不为红蓝交错的满足要求的方案数
记 p p p 一次分裂后产生的区间的集合为 S S S
m i = min i ∈ p h [ i ] mi = \min_{i \in p}h[i] mi=i∈pminh[i]
n u m = ∑ i ∈ p [ h [ i ] = m i ] num = \sum_{i\in p}[h[i] = mi] num=∑i∈p[h[i]=mi]
转移:
1:
f ( p , 0 ) = 2 m i ⋅ ∏ v ∈ S f ( v , 0 ) f (p, 0) = 2^{mi} \cdot \prod_{v \in S} f (v,0) f(p,0)=2mi⋅v∈S∏f(v,0)
什么意思呢?考虑后
m
i
+
1
mi + 1
mi+1 行,我们共有
∏
v
∈
S
f
(
v
,
0
)
\prod_{v \in S} f (v,0)
∏v∈Sf(v,0) 种方案使得第
m
i
+
1
mi + 1
mi+1 行红蓝相间。对于每一种方案,在它下面拼上一行
R
B
R
B
…
RBRB…
RBRB… 或
B
R
B
R
…
BRBR…
BRBR… 都是满足要求的,因为这样拼接,每一个小四方格只有四种情况 :
R
B
R
B
B
R
B
R
R
B
B
R
R
B
B
R
RB \ RB \ BR \ BR \\ RB \ BR \ RB \ BR
RB RB BR BRRB BR RB BR
2:
f ( p , 1 ) = f ( p , 0 ) + 2 n u m ⋅ ∏ v ∈ S ( f ( v , 1 ) + f ( v , 0 ) ) − 2 ∏ v ∈ S f ( v , 0 ) f (p, 1) = f (p, 0) + 2^{num} \cdot \prod_{v \in S} (f (v, 1) + f (v, 0)) - 2 \prod_{v \in S} f (v, 0) f(p,1)=f(p,0)+2num⋅v∈S∏(f(v,1)+f(v,0))−2v∈S∏f(v,0)
这个转移又是什么意思呢?
若第 m i mi mi 为红蓝相间
答案即为 f ( p , 0 ) f (p, 0) f(p,0)
若第 m i mi mi 不为红蓝相间
则答案就是只考虑后 m i mi mi 行的方案数(因为 m i mi mi 行之前的只能是上一行取反的结果,它的方案数和只考虑 m i mi mi 行后的方案数相同)。
h [ i ] = m i h[i] = mi h[i]=mi 的列的下一行红色蓝色都可以,所以贡献是 2 n u m 2 ^ {num} 2num
∏ v ∈ S ( f ( v , 1 ) + f ( v , 0 ) ) ( ∗ ) \prod_{v \in S} (f (v, 1) + f (v, 0))(*) ∏v∈S(f(v,1)+f(v,0))(∗) 相当于是:
我们将 S S S 任意分裂成两个集合 S 1 , S 2 S_1, S_2 S1,S2,让这个 S 1 S_1 S1 集合内的元素的对应序列(相当于是第 m i + 1 mi + 1 mi+1 行)的下一行 (即第 m i mi mi 行) 取反,然后 S 2 S_2 S2 集合内的元素的对应序列的下一行相同 (这种情况必须是最后一行红蓝相间),再乘上 h [ i ] = m i h[i] = mi h[i]=mi 的列的贡献( 2 n u m 2^{num} 2num),则后 m i mi mi 行的方案数为 2 n u m ∑ i ∈ S 1 ∑ j ∈ S 2 f ( i , 1 ) ∗ f ( j , 0 ) 2 ^ {num} \sum_{i \in S_1} \sum_{j \in S_2} f (i, 1) * f (j, 0) 2num∑i∈S1∑j∈S2f(i,1)∗f(j,0),发现这个式子化简就是 ( ∗ ) (*) (∗) 式(可以参考二项式展开定理的证明)。
第 m i mi mi 行为红蓝相间的方案数为: 2 ∏ v ∈ S f ( v , 1 ) 2 \prod_{v \in S} f (v, 1) 2∏v∈Sf(v,1),减去它的贡献就行了。
参考代码:
#include <map>
#include <cmath>
#include <queue>
#include <vector>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;
#define db double
#define LL long long
#define ULL unsigned long long
#define PII pair <int, int>
#define MP(x,y) make_pair (x, y)
#define rep(i,j,k) for (int i = (j); i <= (k); i++)
#define per(i,j,k) for (int i = (j); i >= (k); i--)
template <typename T> T Max (T x, T y) { return x > y ? x : y; }
template <typename T> T Min (T x, T y) { return x < y ? x : y; }
template <typename T> T Abs (T x) { return x > 0 ? x : -x; }
template <typename T>
bool read (T &x) {
x = 0; T f = 1;
char ch = getchar ();
while (ch < '0' || ch > '9') {
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9') {
x = (x << 3) + (x << 1) + ch - '0';
ch = getchar ();
}
x *= f;
return 1;
}
template <typename T>
void write (T x) {
if (x < 0) {
putchar ('-');
x = -x;
}
if (x < 10) {
putchar (x + '0');
return;
}
write (x / 10);
putchar (x % 10 + '0');
}
template <typename T>
void print (T x, char ch) {
write (x); putchar (ch);
}
const int Maxn = 100;
const int Inf = 0x3f3f3f3f;
const LL Mod = 1e9 + 7;
int n;
int h[Maxn + 5];
LL f[2][Maxn * 4 + 5];
LL quick_pow (LL x, LL y) {
LL res = 1;
while (y) {
if (y & 1) res = (res * x) % Mod;
x = (x * x) % Mod; y >>= 1;
}
return res;
}
int pool;
int solve (int l, int r, int delta) {
int p = ++pool;
if (l > r) { return p; }
int _min = Inf;
rep (i, l, r)
_min = Min (_min, h[i]);
vector <int> s;
int last = l, cnt = 0;
rep (i, l, r)
if (_min == h[i]) {
if (last <= i - 1)
s.push_back (solve (last, i - 1, _min));
cnt++;//注意这里不要写在if里面,因为cnt应该表示断点个数,而不是分成的区间个数。
last = i + 1;
}
if (last <= r)//注意这里不加1,因为cnt应该表示断点个数,而不是分成的区间个数。
s.push_back (solve (last, r, _min));
LL res1 = 1, res2 = 1;
rep (i, 0, (int)s.size () - 1) {
int v = s[i];
res1 = res1 * f[0][v] % Mod;
res2 = res2 * (f[0][v] + f[1][v]) % Mod;
}
f[0][p] = quick_pow (2, _min - delta) * res1 % Mod;
f[1][p] = (f[0][p] + quick_pow (2, cnt) * res2 % Mod - 2 * res1 % Mod + Mod) % Mod;
return p;
}
int main () {
cin >> n;
rep (i, 1, n) cin >> h[i];
int rt = solve (1, n, 0);
cout << f[1][rt];
return 0;
}