2024牛客寒假算法基础集训营4 H&K
H
观察下图
1. 只有在横着连续有三个*的时候才可能会出现三角形,并且随着横坐标的增加实际上增加的是(从左往右从上往下方向)斜对角线上点的数量。
2. 当横着连续有3-4个的时候斜线的长度为2,当横着又连续5-6个的时候斜线的长度为3,以此类推,所以启发使用斜对角前缀和来快速解决每一行连续*答案的更新。
3. 但是每个这个斜对角上的对于当前这一行不一定是存在的(例如斜对角上的某个只能斜着延伸到上上行,到上一行没有*接它延伸到当前行了。),所以有些时候需要添加进前缀和,有些时候需要将其从前缀和删除,所以启发用树状数组动态维护斜对角前缀和。
4. 接下来思考如何判断每一个斜对角上*的存活时间。
5. 维护左右斜对角的数量前缀和,通过二分可以在n * n * logn 的时间内预处理出每个往下延伸的最远合法距离d。考虑这个的合法存在时间,从这个所在的第i行开始到i + d - 1行是合法存在的。所以在遍历完第i行之后就可以把第i行所有的加入斜对角的树状数组里面,然后再将在第i行过期的从树状数组里面给删掉。
6. 整个矩阵的斜对角可以考虑分为n + m - 1类,例如4 * 6的可以按如下数字分类。
这题空间限制比较严重,不支持再开一个数组记录类型。观察发现当横坐标x大于等于纵坐标y的时候当前格子的类型就是x - y + 1;当横坐标x小于纵坐标y的时候格子的类型就是y - x + 1 + n。
下图为不同类的下标。
观察发现当横坐标x大于等于纵坐标y的时候当前格子的下标就为y;当横坐标x小于纵坐标y的时候格子的下标就为x。
#include <bits/stdc++.h>
using ll = long long;
typedef std::pair<int, int> PII;
typedef std::array<ll, 3> ay;
const int N = 3e5 + 10, M = N << 1;
const int INF = 0x3f3f3f3f;
const ll inf = 0x3f3f3f3f3f3f3f;
const int MOD = 1e9 + 7;
#define ls u << 1
#define rs u << 1 | 1
int n, m;
int a[N];
inline void solve() {
std::cin >> n >> m;
std::vector<std::vector<char>> G(n + 2, std::vector<char>(m + 2));
std::vector<std::vector<int>> lsum(n + 2, std::vector<int>(m + 2)),
rsum(n + 2, std::vector<int>(m + 2)),
tr(n + m + 2, std::vector<int>(n + 2, 0));
std::vector<std::vector<PII>> del(n + 2, std::vector<PII>());;
for (int i = 1; i <= n; i ++)
for (int j = 1; j <= m; j ++)
std::cin >> G[i][j];
for (int i = 1; i <= n; i ++)
for (int j = 1; j <= m; j ++) {
if (G[i][j] == '*') lsum[i][j] = rsum[i][j] = 1;
else lsum[i][j] = rsum[i][j] = 0;
}
for (int i = 2; i <= n; i ++)
for (int j = 1; j < m; j ++)
lsum[i][j] += lsum[i - 1][j + 1];
for (int i = 2; i <= n; i ++)
for (int j = 2; j <= m; j ++)
rsum[i][j] += rsum[i - 1][j - 1];
for (int i = 1; i <= n; i ++) {
for (int j = 1; j <= m; j ++) {
if (G[i][j] == '*') {
int l = 1, r = std::min(j, m - j + 1);
r = std::min(r, n - i + 1);
while (l < r) {
int mid = l + r + 1 >> 1;
int dx1 = i + mid - 1, dy1 = j - mid + 1;
int dx2 = i + mid - 1, dy2 = j + mid - 1;
if (lsum[dx1][dy1] - lsum[i - 1][j + 1] == mid && rsum[dx2][dy2] - rsum[i - 1][j - 1] == mid) l = mid;
else r = mid - 1;
}
del[i + l - 1].push_back({i, j});
}
}
}
auto query = [&](int pos, int x) -> int {
int res = 0;
for (int i = x; i; i -= (i & (-i))) res += tr[pos][i];
return res;
};
auto add = [&](int pos, int x, int v) -> void{
for (int i = x; i <= m + 1; i += (i & (-i)))
tr[pos][i] += v;
};
auto querylr = [&](int pos, int l, int r) -> int {
return query(pos, r) - query(pos, l - 1);
};
auto getpos = [&](int x, int y) -> int {
if (x >= y) return x - y + 1;
return y - x + 1 + n;
};
auto getx = [&](int x, int y) -> int {
if (x >= y) return y;
return x;
};
ll ans = 0;
for (int i = 1; i <= n; i ++) {
int l = 1, r = 1;
while (l <= m) {
r = l;
int into = 1;
if (G[i][l] == '*')
while (r + 1 <= m && G[i][r + 1] == '*') {
r ++;
if (r - l > 1) {
int c = into;
int dx1 = std::max(i - c, 1), dy1 = r - c;
ans += querylr(getpos(i, r), getx(dx1, dy1), getx(i, r));
if ((r - l) & 1 ) into ++;
}
}
l = r + 1;
}
for (int j = 1; j <= m; j ++) if (G[i][j] == '*') add(getpos(i, j), getx(i, j), 1);
for (auto [x, y]: del[i]) {
add(getpos(x, y), getx(x, y), -1);
}
}
std::cout << ans << '\n';
}
signed main(void) {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
int _ = 1;
//std::cin >> _;
while (_ --) solve();
return 0;
}
K
很显然的线段树题,很明显合并左右两个区间的时候如果右区间遇到B之前有k个R,那么两个区间的总和相加还需要加上左区间最右侧点上的方块数量乘上2^k - 最右侧点上的方块数量。
这题的难点在于线段树维护的信息和pushup,直接看代码即可。
#include <bits/stdc++.h>
using ll = long long;
typedef std::pair<int, int> PII;
typedef std::array<ll, 3> ay;
const int N = 2e5 + 10, M = N << 1;
const int INF = 0x3f3f3f3f;
const ll inf = 0x3f3f3f3f3f3f3f;
const int MOD = 1e9 + 7;
#define ls u << 1
#define rs u << 1 | 1
int n, m;
int a[N];
std::string str;
struct node {
ll sum, tot;//当前区间最右边的方块数量, 总方块数量
bool blue;//是否出现蓝色导致变换行
int count;//蓝色前有多少红色
}tr[N << 2];
ll fastpow(ll x) {
ll cur = 1;
ll t = 2;
while (x) {
if (x & 1) cur = cur * t % MOD;
t = t * t % MOD;
x >>= 1;
}
return cur;
}
inline void pushup(node &root, node &lst, node &rst) {
ll cur = lst.sum * fastpow(rst.count) % MOD;
root.tot = (rst.tot + cur + lst.tot - lst.sum + MOD) % MOD;
root.blue = lst.blue | rst.blue;
if (lst.blue) root.count = lst.count % MOD;
else if (rst.blue) root.count = lst.count + rst.count;
else root.count = lst.count + rst.count;
if (rst.blue) root.sum = rst.sum;
else root.sum = (rst.sum + cur) % MOD;
}
inline void pushup(int u) {
pushup(tr[u], tr[ls], tr[rs]);
}
inline void build(int u, int L, int R) {
if (L == R) {
tr[u].count = 0;
tr[u].blue = false;
tr[u].sum = tr[u].tot = 1;
if (str[L] == 'B') tr[u].blue = true;
else if (str[L] == 'R') tr[u].count = 1;
return ;
}
int mid = L + R >> 1;
build(ls, L, mid);
build(rs, mid + 1, R);
pushup(u);
}
inline void modify(int u, int L, int R, int x, char ch) {
if (L == R) {
tr[u].count = 0;
tr[u].blue = false;
tr[u].sum = tr[u].tot = 1;
if (ch == 'B') tr[u].blue = true;
else if (ch == 'R') tr[u].count = 1;
return ;
}
int mid = L + R >> 1;
if (x <= mid) modify(ls, L, mid, x, ch);
else modify(rs, mid + 1, R, x, ch);
pushup(u);
}
inline node query(int u, int L, int R, int l, int r) {
if (L >= l && R <= r) return tr[u];
int mid = L + R >> 1;
if (r <= mid) return query(ls, L, mid, l, r);
else if (l > mid) return query(rs, mid + 1, R, l, r);
node lst = query(ls, L, mid, l, r);
node rst = query(rs, mid + 1, R, l, r);
node root;
pushup(root, lst, rst);
return root;
}
inline void solve() {
std::cin >> n >> m;
std::cin >> str;
str = " " + str;
build(1, 1, n);
while (m --) {
int op;
std::cin >> op;
if (op == 2) {
int l, r;
std::cin >> l >> r;
node cur = query(1, 1, n, l, r);
std::cout << cur.tot << '\n';
} else {
int pos;
char ch;
std::cin >> pos >> ch;
modify(1, 1, n, pos, ch);
}
}
}
signed main(void) {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
int _ = 1;
//std::cin >> _;
while (_ --) solve();
return 0;
}