类欧几里得算法
类欧几里得算法
最近读具体数学第二、三章有感,遂来挑战一下这题。这题做完以后感觉对求和技巧与规约的理解又更进了一步。
听说有个叫万能欧几里得算法的东西,是这类题目的通解做法,奈何本人水平太菜,有空再补。
类欧几里得算法要求以下三个式子的值:
我们将每个式子设成函数的用处很快就能体现出来,因为类欧几里得算法之所以这么命名,不是因为其推导过程和欧几里得算法有类似之处,而是因为其递归方式和复杂度证明方式与欧几里得算法一致。
算 \(f\)
先考虑最简单的第一个式子如何求,这里说简单只是式子最简单,而其操作方式并不简单,如果明白了如何递归求解 \(f(n, a, b,c)\) 后面两个式子的推导思路就大差不差了。
根据具体数学的启示,我们先观察该式的边界情况,即 \(a=0\) 或 \(b=0\),发现实际上只有 \(a=0\) 的答案是平凡的,\(b=0\) 也不太好处理。当 \(a=0\) 时,\(f(n,a,b,c) = (n+1)\lfloor \dfrac bc\rfloor\)。
既然 \(a=0\) 是好处理的,那么我们能否像欧几里得算法那样不断递归,将任何 \(a\) 都消成 \(0\) 呢?、
首先考虑 \(a\ge c\) 或者 \(b\ge c\) 的情况,受到具体数学 \(P77\) 对于类似的一个和式的处理的启发,我们能够将 \(a\) 与 \(b\) 拆分成 \(\lfloor \dfrac{a}{c}\rfloor c+ (a\bmod c)\) 与 \(\lfloor \dfrac{b}{c}\rfloor c+ (b\bmod c)\),使得原式的范围缩小,从而得到规约的效果:
不断执行上面的操作只能让 \(a< c\) 且 \(b <c\),接下来该怎么进一步缩小范围呢?
当 \(a<c\) 且 \(b<c\) 时,我们刚才换 \(a,b\) 的方法显然是行不通了,我们不妨换一种求和的思路,还记得具体数学 \(P73\) 的思想吗?我们不妨用 \(\sum_j [1\le j\le x]\) 来替换 \(\lfloor x\rfloor\)。
有两点需要注意:
- 第三步的代换十分巧妙,需要好好体会,因为这么做可以使区间变为左开右闭区间,这样统计区间内整数个数时减去的是一个下取整函数,而这个下取整函数的形式恰好可以规约
- \(b\) 对我们而言作用不大,我们主要想让这个类欧几里得过程最终能够到达 \(a=0\) 的边界情况,而我们这部分的操作相当于将 \(a,c\) 位置互换,从而使得 \(a\) 始终在减小,故而 \(a\) 最终总会到达 \(0\),其实如果仔细看的话,只观察 \(a,c\) 的变化,就等同于欧几里得算法,根据主定理可以分析复杂度的确是 \(O(\log n)\)
总结来说,
算 \(g\) 与 \(f\)
为什么 \(g\) 和 \(f\) 我并到一块了呢?看完后面的推导,我们就能够发现这两者是互推的。
先考虑 \(g\) 的求法,同样,先讨论 \(a=0\) 的平凡情况,此时 \(g(n, a, b,c)= (n+1)\lfloor\dfrac{b}{c}\rfloor ^2\)
当 \(a\ge c\) 或者 \(b\ge c\) 时,推导差不多,只需要将平方展开即可:
当 \(a<c\) 且 \(b<c\) 时,推导需要运用到具体数学 \(P31\) 页的 \((2.33)\) 变换和 \(P39\) 的展开和收缩变换:
最后考虑 \(h\) 的求法,\(a=0\) 时,\(h(n,a,b,c) = \dfrac{n(n+1)}2 \lfloor\dfrac bc\rfloor\)
当 \(a\ge c\) 或 \(b\ge c\) 时,推导仍然比较简单,暴力展开就行:
当 \(a<c\) 且 \(b<c\) 时,和之前一样:
推导工作全部进行完毕。
发现 \(g,h\) 是互推的,\(f\) 是独立的,即可写出程序。
然而常数过大,无法接受,但是,我们欣喜地发现,在同一种 \(a,b,c\) 的关系之下,\(f,g,h\) 所需的递归都是一致的,因此我们不妨同步算 \(f,g,h\),一起计算的时间复杂度为 \(O(\log n)\)
附上代码供参考:
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long ll;
const int mod = 998244353;
ll inv2 = 499122177, inv6 = 166374059;
inline ll mul(ll a, ll b) { return (a * b) % mod; }
inline ll add(ll a, ll b) { return (a + b) % mod; }
inline ll sqr(ll a) { return mul(a, a); }
/*
ll f(ll n, ll a, ll b, ll c); // sum (ai + b) / c
ll g(ll n, ll a, ll b, ll c); // sum ((ai + b) / c) ^ 2
ll h(ll n, ll a, ll b, ll c); // sum i((ai + b) / c)
ll f(ll n, ll a, ll b, ll c) {
if (!a) {
return mul(n + 1, b / c);
} else if (a >= c || b >= c) {
ll co1 = f(n, a % c, b % c, c), co2 = mul(mul(mul(a / c, n), n + 1), inv2);
ll co3 = mul(b / c, n + 1);
return add(co1, add(co2, co3));
} else {
ll m = (a * n + b) / c;
return add(mul(m, n), mod - f(m - 1, c, c - b - 1, a));
}
}
ll g(ll n, ll a, ll b, ll c) {
if (!a) {
return mul(n + 1, sqr(b / c));
} else if (a >= c || b >= c) {
ll co1 = g(n, a % c, b % c, c), co2 = mul(sqr(a / c), mul(n, mul(n + 1, mul(2 * n + 1, inv6))));
ll co3 = mul(n + 1, sqr(b / c)), co4 = mul(2, mul(a / c, h(n, a % c, b % c, c)));
ll co5 = mul(a / c, mul(b / c, mul(n, n + 1))), co6 = mul(2, mul(b / c, f(n, a % c, b % c, c)));
return add(co1, add(co2, add(co3, add(co4, add(co5, co6)))));
} else {
ll m = (a * n + b) / c;
ll co1 = mul(n, mul(m, m + 1)), co2 = mod - mul(2, h(m - 1, c, c - b - 1, a));
ll co3 = mod - mul(2, f(m - 1, c, c - b - 1, a)), co4 = mod - f(n, a, b, c);
return add(co1, add(co2, add(co3, co4)));
}
}
ll h(ll n, ll a, ll b, ll c) {
if (!a) {
return mul(b / c, mul(n, mul(n + 1, inv2)));
} else if (a >= c || b >= c) {
ll co1 = h(n, a % c, b % c, c), co2 = mul(a / c, mul(n, mul(n + 1, mul(2 * n + 1, inv6))));
ll co3 = mul(b / c, mul(n, mul(n + 1, inv2)));
return add(co1, add(co2, co3));
} else {
ll m = (a * n + b) / c;
ll co1 = mul(m, mul(n, mul(n + 1, inv2))), co2 = mod - mul(g(m - 1, c, c - b - 1, a), inv2);
ll co3 = mod - mul(f(m - 1, c, c - b - 1, a), inv2);
return add(co1, add(co2, co3));
}
}
*/
struct node { ll f, g, h; };
node solve(ll n, ll a, ll b, ll c) {
node ret; ll co1, co2, co3, co4, co5, co6;
if (!a) {
ret.f = mul(n + 1, b / c);
ret.g = mul(n + 1, sqr(b / c));
ret.h = mul(b / c, mul(n, mul(n + 1, inv2)));
} else if (a >= c || b >= c) {
node tmp = solve(n, a % c, b % c, c);
co1 = tmp.f, co2 = mul(mul(mul(a / c, n), n + 1), inv2), co3 = mul(b / c, n + 1);
ret.f = add(co1, add(co2, co3));
co1 = tmp.g, co2 = mul(sqr(a / c), mul(n, mul(n + 1, mul(2 * n + 1, inv6))));
co3 = mul(n + 1, sqr(b / c)), co4 = mul(2, mul(a / c, tmp.h));
co5 = mul(a / c, mul(b / c, mul(n, n + 1))), co6 = mul(2, mul(b / c, tmp.f));
ret.g = add(co1, add(co2, add(co3, add(co4, add(co5, co6)))));
co1 = tmp.h, co2 = mul(a / c, mul(n, mul(n + 1, mul(2 * n + 1, inv6))));
co3 = mul(b / c, mul(n, mul(n + 1, inv2)));
ret.h = add(co1, add(co2, co3));
} else {
ll m = (a * n + b) / c;
node tmp = solve(m - 1, c, c - b - 1, a);
co1 = mul(m, n), co2 = mod - tmp.f;
ret.f = add(co1, co2);
co1 = mul(n, mul(m, m + 1)), co2 = mod - mul(2, tmp.h);
co3 = mod - mul(2, tmp.f), co4 = mod - ret.f;
ret.g = add(co1, add(co2, add(co3, co4)));
co1 = mul(m, mul(n, mul(n + 1, inv2))), co2 = mod - mul(tmp.g, inv2);
co3 = mod - mul(tmp.f, inv2);
ret.h = add(co1, add(co2, co3));
}
return ret;
}
int main() {
int T; scanf("%d", &T);
while (T--) {
ll n, a, b, c; scanf("%lld%lld%lld%lld", &n, &a, &b, &c);
node ans = solve(n, a, b, c);
printf("%lld %lld %lld\n", ans.f, ans.g, ans.h);
}
return 0;
}