[暴力 Trick] 根号分治
根号分治
PS:本篇博客题目分析及内容(除代码)均来自于paulzrm
根号分治,是暴力美学的集大成体现。与其说是一种算法,我们不如称它为一个常用的trick。
首先,我们引入一道入门题目 CF1207F Remainder Problem:
给你一个长度为 $5\times10^5$ 的序列,初值为 $0$ ,你要完成 $q$ 次操作,操作有如下两种:
1 x y
: 将下标为 $x$ 的位置的值加上 $y$。2 x y
: 询问所有下标模 $x$ 的结果为 $y$ 的位置的值之和。
考虑这题的暴力是什么。
首先有一种暴力就是按照题目所说的去做,开一个 $5\times10^5$ 大小的数组 $a$ 去存,$1$ 操作就对 $a_x$ 加上 $y$,$2$ 操作就枚举所有下标模 $x$ 的结果为 $y$ 的位置,统计他们的和。
对于这种暴力,$1$ 操作的时间复杂度为 $O(1)$,$2$ 操作的时间复杂度为 $O(n)$,所以在最坏情况下总时间复杂度可达 $O(nq)$。
经过思考,我们可以发现另外一种暴力:新开一个大小为 $n\times n$ 的二维数组 $b$,$b_{i,j}$ 当前所有下标模 $i$ 的结果为 $j$ 的数的和是什么。对于每个 $1$ 操作,动态的去维护这个 $b$ 数组,在每次询问的时候直接输出答案即可。
对于这种暴力,$1$ 操作的时间复杂度是枚举模数的 $O(n)$ ,$2$ 操作的时间复杂度为 $O(1)$,总的时间复杂度为 $O(nq)$。
现在我们发现,这两种暴力对应了两种极端:一个是 $1$ 操作的时间复杂度为 $O(1)$,$2$ 操作的时间复杂度为 $O(n)$;另一个是 $1$ 操作的时间复杂度是枚举模数的 $O(n)$,$2$ 操作的时间复杂度为 $O(1)$。那么,有没有办法让这两种暴力融合一下,均摊时间复杂度,达到一个平衡呢?
其实是有的。我们设定一个阈值 $b$。
对于所有 $\le b$ 的数,我们动态的维护暴力 $2$ 的 $b$ 数组。每次 $1$ 操作只需要枚举 $b$ 个模数即可,故单次操作 $1$ 的时间复杂度降为 $O(b)$。
对于所有 $>b$ 的数,我们就不在操作 $1$ 中维护 $b$,直接再询问答案时暴力枚举下标即可。显然,这 $n$ 个下标中最多有 $\lceil \frac{n}{b}\rceil$ 个下标对 $x$ 取模余 $y$ 找到第一个 $y$ 后每次跳 $x$,即可做到单次操作 $2$ 时间复杂度为 $O(\frac{n}{b})$。
所以,总时间复杂度就成为了 $O(q\times(b+\frac{n}{b}))$。由基本不等式可得,$b+\frac{n}{b} \geq 2\sqrt{b\times\frac{n}{b}}=2\sqrt{n}$,当 $b=\sqrt{n}$ 时取等。所以我们只需要让 $b=\sqrt{n}$,就可以做到时间和空间复杂度均为 $O(q\sqrt{n})$ 的优秀算法了,可以通过此题。
#include<bits/stdc++.h>
#define rint register int
#define endl '\n'
const int N = 8e2 + 5;
const int M = 5e5 + 5;
using namespace std;
int s[N][N], a[M];
signed main()
{
int q;
cin >> q;
int n = sqrt(500000);
while(q--)
{
int op, x , y;
cin >> op >> x >> y;
if (op == 1)
{
for(rint i = 1 ; i < n; i++)
{
s[i][x % i] += y;
}
a[x] += y;
}
if (op == 2)
{
if(x < n)
{
cout << s[x][y] << endl;
}
else
{
int res = 0;
for(rint i = y; i <= 500000; i += x)
{
res += a[i];
}
cout << res << endl;
}
}
}
return 0;
}
CF710D Two Arithmetic Progressions
题目大意:
现在有两个等差数列,形如 $a_1k+b_1$ 和 $a_2k+b_2$,其中 $k$ 要满足是自然数。现在再给你两个正整数 $l,r$,求出 $[l,r]$ 间有多少个数同时出现在两个等差数列中。数据满足 $0<a_1,b_1\le2\times10^9,-2\times10*9\le b_1,b_2,l,r\le 2\times10^9,l\le r$。
题解:
正解要用到 exgcd 等数论知识,且细节较多比较麻烦。现在我们考虑如何用根号分治解决该数论问题。
现在钦定 $a_1\geq a_2$,再令 $t=\sqrt{2\times 10^9}$。
$a_1\le t$。此时 $a_2$ 也 $\le t$。由于每隔 $lcm(a_1,b_1)$ 就是一个循环节,且每个循环节只会有 $1$ 的贡献,我们只需要找到第一个重合的数(或报告不存在),然后计算出循环节的个数就可以了。找到第一个重合的数,可以直接对着第一个等差数列从前往后跳,如果跳了 $a_2$ 次还是没有出现,可以证明一定不存在了。
$a_1>t$。那么有 $\frac{2\times10^9}{a_1}\le t$。也就是说,在 $[l,r]$ 这段区间内,属于等差数列 $1$ 的数不会超过根号个。我们只需要枚举这个根号个数,依次判断其是否在等差数列 $2$ 中即可。
#include <bits/stdc++.h>
#define rint register int
#define int long long
#define endl '\n'
using namespace std;
const int N = 2e9;
int a1, a2, b1, b2;
int l, r;
int n;
signed main()
{
n = sqrt(N);
cin >> a1 >> a2 >> b1 >> b2;
cin >> l >> r;
int m = max(a1, b1);
if (m <= n)
{
for (rint i = -m * 2; i <= m * 2; i++)
{
int p = i * a1 + a2;
if ((abs(b2 - p) % b1 == 0))
{
int k = __gcd(a1, b1);
int lcm = a1 * b1 / k;
int begin = max(max(a2, b2), l);
if (p > begin)
{
p = begin + (p - begin) % lcm;
}
else
{
p += ((begin - p) / lcm + 1) * lcm;
p = begin + (p - begin) % lcm;
}
if (r < p)
{
continue;
}
cout << (r - p) / lcm + 1 << endl;
return 0;
}
}
cout << 0 << endl;
return 0;
}
else
{
if (a1 < b1)
{
swap(a1, b1);
swap(a2, b2);
}
int cnt = 0;
for (rint i = a2; i <= r; i += a1)
{
if (i >= l && i >= b2)
{
if ((i - b2) % b1 == 0)
{
cnt++;
}
}
}
cout << cnt << endl;;
}
return 0;
}
题目大意:
给定两个正整数 $K,M (1\le K,M \le 10^{10})$,你需要求出有多少个正整数 $N$ 满足 $1 \le N \le M$ 且 $N \equiv S_N (\mod K) $,其中 $S_N$ 是 $N$ 的各位数字之和。
题解:
这个 $10^{10}$ 的数据范围并不常见,但是可以发现大概是根号的复杂度。
显然无法分块,考虑怎么做到根号分治。我们先对 $K$ 设定一个阈值 $T$,其中 $T$ 是 $\sqrt{M}$ 级别。
- $K \ge T$
当 $1 \le N \le 10^{10}$ 时,最大的 $S_N$ 不过 $9\times 10=90$,所以我们可以去先枚举数字和 $S$,然后就可以发现,$\mod K= S$ 的 $N$ 的个数不会超过 $\lfloor\frac{M}{K}\rfloor +1$ 个。直接枚举这些数就可以了。复杂度 $O(90\times \frac{M}{K})$。
- $K \le T$
我们可以考虑把 $K$ 做为一维压到数位 dp 里了。令 $dp_{i,j,sm,0/1}$ 表示考虑到从高到低第 $i$ 位,此时的数 $\mod K = j$,数字和为 $sm$,是否已经小于 $m$ 的数的个数。这样就可以 dp 了,复杂度 $O(10\times90\times K\times 10)=O(9000K)$。
#include <bits/stdc++.h>
#define rint register int
#define int long long
#define endl '\n'
using namespace std;
const int N = 1e4 + 5;
const int M = 1e8;
const int K = 1e2 - 10;
const int W = 1e1 + 2;
int k, m, n;
int len;
int a[W];
int f[W][N][K][2];
int ans;
signed main()
{
cin >> k >> m;
len = sqrt(M);
int tool = m;
for (rint i = 1; ;i++)
{
a[i] = tool % 10;
tool /= 10;
if (tool == 0)
{
n = i;
break;
}
}
reverse(a + 1, a + n + 1);
if (k >= len)
{
for (rint i = 0; i <= K; i++)
{
for (rint j = i; j <= m; j += k)
{
int t = j;
int cnt = 0;
while (t)
{
cnt += t % 10;
t /= 10;
}
if (cnt % k == i)
{
ans++;
}
}
}
cout << ans - 1 << endl;
return 0;
}
f[1][0][0][0] = 1;
for (rint i = 1; i <= n; i++)
{
for (rint j = 0; j < k; j++)
{
for (rint o = 0; o <= K; o++)
{
for (rint t = 0; t < a[i]; t++)
{
if (o + t <= K)
{
f[i + 1][(j * 10 + t) % k][o + t][1] += f[i][j][o][0];
}
}
if (o + a[i] <= K)
{
f[i + 1][(j * 10 + a[i]) % k][o + a[i]][0] += f[i][j][o][0];
}
for (rint t = 0; t < 10; t++)
{
if (o + t <= K)
{
f[i + 1][(j * 10 + t) % k][o + t][1] += f[i][j][o][1];
}
}
}
}
}
for (rint i = 0; i <= K; i++)
{
ans += f[n + 1][i % k][i][0];
ans += f[n + 1][i % k][i][1];
}
cout << ans - 1 << endl;
return 0;
}