@bzoj - 4377@ [POI2015] Kurs szybkiego czytania
@description@
给定 n, a, b, p,其中 n, a 互质。定义一个长度为 n 的 01 串 c[0..n-1],其中 c[i] == 0 当且仅当 (ai+b) mod n < p。
给定一个长为 m 的小 01 串,求出小串在大串中出现了几次。
input
第一行包含整数 n, a, b, p, m (2<=n<=10^9, 1<=p, a, b, m<n, 1<=m<=10^6)。n 和 a 互质。
第二行一个长度为 m 的 01 串。
output
一个整数,表示小串在大串中出现了几次
sample input
9 5 6 4 3
101
sample output
3
sample explain
@solution@
假如某次出现,小 01 串的第 1 个位置对应大 01 串的第 i 个位置(i <= n - m),令 x = (ai + b) mod n。
如果小 01 串的第 1 个为 0,则 x < p;否则 x >= p。
如果小 01 串的第 2 个为 0,则 (x + a) mod n < p;否则 (x + a) mod n >= p。
……
如果小 01 串的第 i 个为 0,则 (x + a*i) mod n < p;否则 (x + a*i) mod n >= p。
我们可以解模意义下的不等式,将所有不等式的解集取交集就可以求出满足条件的 x,然后判断 x 对应的 i 是否满足 i <= n - m。
但是这个算法虽然正确,细节方面仍然不是很清晰。
解不等式可以分类讨论,也可以这样来理解:
对于 (x + a*i) mod n < p,可以理解为由点 x 出发在模意义下移动 a*i 步到达区间 [0, p - 1]。
我们可以相对运动:区间 [0, p - 1] 在模意义下反向移动 a*i 步,就是不等式 (x + a*i) mod n < p 的解集。
由于是在模意义下移动,最终的解集区间形如 [l, r] 或 [l, n)∪[0, r]。
然而问题就来了:如果一些解集长成 [l, n)∪[0, r] 的形式,所有解集的交集就可能长得非常的奇怪(因为我们相当于求集合并集的交集),如果暴力取交我们可能需要 O(n^2) 的时间完成。
一种简洁的解决方案是用补集思想。我们知道,集合先取交集再取补集 = 集合先取补集再取并集。
我们一开始将所有解集取补集(因为是区间形式,这个很容易实现),然后就是区间并问题,排序+扫描一遍可以搞定。最后将集合并集再取补集就可以了。
同时,这也允许我们将 i > n-m 这些不合法的情况当作合法情况的补集,跟上面的一起取并集。
@accepted code@
#include<set>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int MAXN = 1000000;
struct segment{
int le, ri;
segment(int _l=0, int _r=0):le(_l), ri(_r){}
}seg[3*MAXN + 5];
bool operator < (segment a, segment b) {
return a.le < b.le;
}
char s[MAXN + 5];
int main() {
int n, a, b, p, m;
scanf("%d%d%d%d%d", &n, &a, &b, &p, &m);
scanf("%s", s);
int cnt = 0;
for(int i=0;i<m;i++) {
int tmp = (n - 1LL*i*a%n)%n, le, ri;
if( s[i] == '0' )
le = tmp%n, ri = (p + tmp - 1)%n;
else le = (p + tmp)%n, ri = (n + tmp - 1)%n;
if( le <= ri )
seg[++cnt] = segment(0, le - 1), seg[++cnt] = segment(ri + 1, n - 1);
else seg[++cnt] = segment(ri + 1, le - 1);
}
for(int i=n-1;i>n-m;i--) {
int tmp = (1LL*a*i%n + b)%n;
cnt++, seg[cnt].le = tmp, seg[cnt].ri = tmp;
}
cnt++, seg[cnt].le = n, seg[cnt].ri = n;
sort(seg + 1, seg + cnt + 1);
int ans = 0, ri = -1;
for(int i=1;i<=cnt;i++) {
if( seg[i].le <= ri )
ri = max(ri, seg[i].ri);
else {
ans += max(0, seg[i].le - ri - 1);
ri = seg[i].ri;
}
}
printf("%d\n", ans);
}
@details@
一开始想得很简单,没有想到 [l, n)∪[0, r] 这样的集合取交集会出现断断续续的区间,弄得我……想要吐血了。
后来的做法也是用了补集,但只有 [l, n)∪[0, r] 才取补集,因为它们取补集之后就变成了一个正常的线段 [r + 1, l - 1] 了。
实测上面的做法要比全部去补集的常数小一些。