【大联盟】20230713 T1 方向矩阵(rect) 题解 CF1666A 【Admissible Map】
题目描述
here。
题解
赛时得分:60/100。
想到了正解,但调不出来,就改写暴力了。。。
首先,我们把问题转化成每个点都入度为 \(1\)。
我们考虑合法子串只有两种形式:
注意到 U
和 D
,要么同时出现,要么同时不出现,因为如果存在 U
,就说明 U
所在这一行得到度数减少了,一定需要上一行 D
来弥补。
-
不存在
U
、D
。答案形如RLRL...RLRL
,这是好统计的。 -
存在
U
、D
。考虑第一个U
一定会匹配第一个未被L
和R
得到入度的点。这是因为,首先,由于有U
、D
,则一定会有未被L
和R
得到入度的点,那这个点肯定会被U
相连,因为如果被D
相连,那这个D
之前肯定会有点未被L
和R
得到入度,因为边数 < 点数,与我们定义的第一个未被L
和R
得到入度的点相矛盾。现在,我们就得到了宽度 \(L\)。然后,我们考虑哈希来判断是否合法(赛时想到的是
bitset
巨难写……)。然后,我们考虑按 \(L\) 根号分治:
-
对于 \(L\le\sqrt{n}\),由于不超过 \(\sqrt{n}\) 种,所以我们预处理出哈希值,然后使用
unordered_map
求答案。 -
对于 \(L>\sqrt{n}\),由于答案不超过 \(\sqrt{n}\),所以我们可以暴力往后跳,判断答案。
-
时间复杂度 \(O(n\sqrt{n})\)。
代码
由于 CF 上 \(n\le 2\times 10^4\),模拟赛虽然 \(n\le 10^5\),但听说数据很水,所以直接写了个 \(>\sqrt{n}\) 的部分就摆了。
#include <bits/stdc++.h>
#define SZ(x) (int) x.size() - 1
#define ms(x, y) memset(x, y, sizeof x)
#define all(x) x.begin(), x.end()
#define F(i, x, y) for (int i = (x); i <= (y); ++i)
#define DF(i, x, y) for (int i = (x); i >= (y); --i)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
template <typename T> inline void chkmax(T &x, T y) { x = max(x, y); }
template <typename T> inline void chkmin(T &x, T y) { x = min(x, y); }
template <typename T> inline void read(T &x) {
x = 0; int f = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
x *= f;
}
const int N = 1e5 + 10, base = 312349, MOD = 1000000021;
int n, lst, pw[N], sum[4][N], ss[N];
ll ans;
int power(int x, int y = MOD - 2) {
int ans = 1;
for (; y; x = (ll) x * x % MOD, y >>= 1)
if (y & 1) ans = (ll) ans * x % MOD;
return ans;
}
void add(int &x, int y) { if ((x += y) >= MOD) x -= MOD; }
signed main() {
// freopen("rect.in", "r", stdin);
// freopen("rect.out", "w", stdout);
string st; cin >> st;
n = st.size(); st = ' ' + st;
int invbase = power(base);
pw[0] = 1;
F(i, 1, n) {
pw[i] = (ll) pw[i - 1] * base % MOD;
ss[i] = (ss[i - 1] + pw[i]) % MOD;
F(j, 0, 3) sum[j][i] = sum[j][i - 1];
if (st[i] == 'U') add(sum[0][i], pw[i]);
if (st[i] == 'D') add(sum[1][i], pw[i]);
if (st[i] == 'L') add(sum[2][i], pw[i]);
if (st[i] == 'R') add(sum[3][i], pw[i]);
}
for (int i = 2; i <= n; i += 2) {
if (st[i] == 'L' && st[i - 1] == 'R') lst++;
else lst = 0;
ans += lst;
}
lst = 0;
for (int i = 3; i <= n; i += 2) {
if (st[i] == 'L' && st[i - 1] == 'R') lst++;
else lst = 0;
ans += lst;
}
int pos = 1, pp = 1;
F(i, 1, n) {
chkmax(pos, i), chkmax(pp, i);
while (pos <= n && st[pos] != 'U') pos++;
int tp = - 1;
if (st[i + 1] == 'L') {
tp = pp;
pp = i;
}
while (pp <= n && (st[pp + 1] == 'L' || (pp != i && st[pp - 1] == 'R'))) pp++;
if (pos > n) break;
if (pos == i) continue;
if (pp >= pos) continue;
int len = pos - pp;
// if (len > B) {
int inv = power(invbase, len), pw = power(base, len);
for (int l = i, r = i + len - 1; r <= n; l += len, r += len) {
if (st[l] == 'L' || st[r] == 'R') break;
int val = (((ll) (sum[0][r] - sum[0][i - 1]) * inv + (ll) (sum[1][r] - sum[1][i - 1]) * pw + (ll) (sum[2][r] - sum[2][i - 1]) * invbase + (ll) (sum[3][r] - sum[3][i - 1]) * base) % MOD + MOD) % MOD;
if (r >= pos && val == (ss[r] - ss[i - 1] + MOD) % MOD) ans++;
}
// }
if (~tp) pp = tp;
}
cout << ans;
return 0;
}