「解题报告」ARC123F Insert Addition
我啥都不会啊唔。
我们考虑不使用数来刻画这个东西,而是使用一个系数对来刻画这个东西,即 \((x, y)\) 表示 \(ax+by\)。那么我们相当于是初始有 \((1, 0), (0, 1)\),每次相邻的两个二元组对应位置相加,即 \((a, b), (a+c, b+d), (c, d)\)。
发现这个过程与 Stern-Brocot 树的构建过程是一模一样的,那么我们其实就是要找 Stern-Brocot 树上满足 \(ax+by \le n\) 的 \(\frac{y}{x}\) 中前 \([L, R]\) 大的数。
求前 \([L, R]\) 大的问题是比较经典的,我们可以先找出第 \(L\) 大的,然后再递推找出 \([L, R]\) 大的。放在 Stern-Brocot 树上,就是先找出第 \(L\) 大的节点位置,然后接着在树上 DFS 找出 \([L, R]\) 大的数。由于 Stern-Brocot 树高 \(O(n)\),那么我们最多只会经过 \(O(n)\) 个无用节点,其他访问到的节点必定是一个答案,所以这部分的复杂度为 \(O(R - L + n)\)。
我们考虑如何找第 \(L\) 大的节点。我们可以在 Stern-Brocot 树上二分,那么我们只需要快速求得每个子树有多少节点即可。
那么我们需要解决的问题为:
求有多少对 \((x, y)\),满足:
- \(\gcd(x, y) = 1\);
- \(ax+by \le n\);
\(\gcd(x, y) = 1\) 的限制显然可以莫反搞掉,简单推式子可以得到答案为:
数论分块套类欧可以做到 \(O(\sqrt n \log n)\) 计算。
然后我们就跑即可。注意 Stern-Brocot 树上跳相同方向的链时需要倍增跳,复杂度 \(O(n + \sqrt{n} \log^3 n)\)。
另一种做法是实数二分,然后再找出分数表示。maspy 的代码中有一个奇妙的递推式子,但是递推式子属实没有看懂,咕了。
类欧实在不想写了,贺了一份 atcoder 的板子。
正解跑的果然比乱搞快,直接最优解了。
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 300005;
int a, b, n;
long long l, r;
int mu[MAXN];
int pri[MAXN], vis[MAXN], pcnt;
long long f[MAXN];
namespace atcoder {
namespace internal {
// @param m `1 <= m`
// @return x mod m
constexpr long long safe_mod(long long x, long long m) {
x %= m;
if (x < 0) x += m;
return x;
}
// @param n `n < 2^32`
// @param m `1 <= m < 2^32`
// @return sum_{i=0}^{n-1} floor((ai + b) / m) (mod 2^64)
unsigned long long floor_sum_unsigned(unsigned long long n,
unsigned long long m,
unsigned long long a,
unsigned long long b) {
unsigned long long ans = 0;
while (true) {
if (a >= m) {
ans += n * (n - 1) / 2 * (a / m);
a %= m;
}
if (b >= m) {
ans += n * (b / m);
b %= m;
}
unsigned long long y_max = a * n + b;
if (y_max < m) break;
n = (unsigned long long)(y_max / m);
b = (unsigned long long)(y_max % m);
std::swap(m, a);
}
return ans;
}
}
long long floor_sum(long long n, long long m, long long a, long long b) {
assert(0 <= n && n < (1LL << 32));
assert(1 <= m && m < (1LL << 32));
unsigned long long ans = 0;
if (a < 0) {
unsigned long long a2 = internal::safe_mod(a, m);
ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / m);
a = a2;
}
if (b < 0) {
unsigned long long b2 = internal::safe_mod(b, m);
ans -= 1ULL * n * ((b2 - b) / m);
b = b2;
}
return ans + internal::floor_sum_unsigned(n, m, a, b);
}
}
long long count(long long a, long long b) {
long long ans = 0;
for (int l = 1, r; l <= n; l = r + 1) {
r = n / (n / l);
if (n / l < a + b) break;
ans += (mu[r] - mu[l - 1]) * atcoder::floor_sum(n / l / a, b, -a, n / l - a);
}
return ans;
}
vector<int> ans;
typedef pair<long long, long long> fraction;
fraction operator+(fraction x, fraction y) {
return { x.first + y.first, x.second + y.second };
}
fraction operator*(long long x, fraction y) {
return { x * y.first, x * y.second };
}
long long calc(fraction x) {
return 1ll * a * x.first + 1ll * b * x.second;
}
void insert(fraction x) {
ans.push_back(calc(x));
}
void solve2(fraction x, fraction y, long long l, long long r) {
if (ans.size() >= r - l + 1) return;
long long mid = calc(x + y);
if (mid > n) return;
solve2(x, x + y, l, r);
if (ans.size() < r - l + 1) {
insert(x + y);
}
solve2(x + y, y, l, r);
}
void solve1(fraction x, fraction y, long long l, long long r) {
if (ans.size() >= r - l + 1) return;
long long a = calc(x), b = calc(y), mid = calc(x + y);
if (mid > n) return;
long long cnt = count(calc(x), mid) + 1;
if (l == cnt) {
insert(x + y);
solve2(x + y, y, l, r);
return;
} else if (l < cnt) {
int dep = 0;
for (int i = 20; i >= 0; i--) {
if (l < count(a, calc((dep + (1 << i)) * x + y)) + 1) {
dep += (1 << i);
}
}
solve1(x, dep * x + y, l, r);
for (int i = dep; i >= 1; i--) {
if (ans.size() < r - l + 1) {
insert(i * x + y);
solve2(i * x + y, (i - 1) * x + y, l, r);
} else break;
}
} else {
int dep = 0;
long long tot = count(a, b);
for (int i = 20; i >= 0; i--) {
if (tot - l + 1 < count(calc(x + (dep + (1 << i)) * y), b) + 1) {
dep += (1 << i);
}
}
cnt = tot - count(calc(x + dep * y), b);
solve1(x + dep * y, y, l - cnt, r - cnt);
}
}
int main() {
scanf("%d%d%d%lld%lld", &a, &b, &n, &l, &r);
mu[1] = 1;
for (int i = 2; i <= n; i++) {
if (!vis[i]) {
pri[++pcnt] = i;
mu[i] = -1;
}
for (int j = 1; j <= pcnt && i * pri[j] <= n; j++) {
vis[i * pri[j]] = 1;
if (i % pri[j] == 0) {
break;
} else {
mu[i * pri[j]] = -mu[i];
}
}
mu[i] += mu[i - 1];
}
long long cnt = count(a, b);
if (l != cnt + 2 && r != 1) {
solve1({ 1, 0 }, { 0, 1 }, max(l - 1, 1ll), min(r - 1, cnt));
}
if (l == 1) printf("%d ", a);
for (int i : ans) {
printf("%d ", i);
}
if (r == cnt + 2) printf("%d ", b);
printf("\n");
return 0;
}