大整数运算算法(C)

引言

本文的算法大量参考《计算机程序设计艺术》(The Art of Computer Programming)的算法,代码部分大量参考 Java 的 BigInteger 库。

为了便于理解,文中的代码为无符号的大整数,有符号的大整数可以在此基础上进行进一步封装。代码在效率上还有些许提升空间,并且不做非法输入的检查。

本文包含的算法有 大数比较大数加法大数减法大数乘法大数除法大数与字符串转化

本文的代码已经开源至我个人的仓库:https://gitee.com/oldprincess/bint

一、大整数结构

1. 大整数表示

对于一个整数,可以有多种表示方法,例如二进制、十进制、十六进制等等。定义符号 \(b\) 表示基底,二进制时 \(b=2\),十进制时 \(b=10\),那么一个整数 \(n\) 可以表示成

\[n = \sum_{i=0}^k a_i b^i, \quad a \in [0, \; b-1] \]

例如十进制数 \((1234)_{10}\) 可记为 \(1\times 10^3 + 2 \times 10^2 + 3 \times 10 + 4\),十六进制数 \((AB)_{16}\) 可记为 \(10 \times 16+ 11\)

在大数库中,采用 \(b=2^{32}\) 作为基底,一是因为此时系数项 \(a\) 的取值范围为 \([0, \; 2^{32}-1]\),刚好是一个 32 位无符号整数所表示的范围,可以有效地利用内存;二是因为现代计算机大多是 64 位系统,32 位无符号整数的运算结果能够被 64 位无符号整数表示,能够方便地获取运算时进位的数值。

在具体程序实现时,使用一段内存来表示大整数,存储整数表示时的系数 \(a_i\)。此处采取大端存储的方式,即数据高位在高地址,低位在低地址。例如 \((a_2 a_1 a_0)_b\) 中,\(a_2\) 的存储地址要高于 \(a_0\) 的地址。
这么做的优势在于,在后续整数长度扩大时,可以方便地扩充内存,而保留原始的数据内容不变。

2. 相关代码

// bint.h
#include <stdint.h>

#define BINT_SIZE 256

typedef struct {
    uint32_t value[BINT_SIZE];  // 内存,初始化为0
    int32_t dsize;              // 数据长度
} BINT;                         // 大整数结构

typedef struct {
    BINT q;    // 商
    BINT r;    // 余数
} BINT_DIV_T;  // 除法返回值

#define BINT_NEW() \
    { .value = {0}, .dsize = 0 }

二. 大整数算法

1. 大整数比大小

对于两个 \(b\) 进制大整数 \(u=(u_{n}u_{n-1}\cdots u_0)_b\)\(v=(v_{m} v_{m-1} \cdots v_0)_b\),比较大小算法为

  • \(\textbf{if} \quad n > m\)
  • \(\quad \textbf{return 1} \qquad \text{// u 大于 v}\)
  • \(\textbf{else if} \quad n < m\)
  • \(\quad \textbf{return -1} \qquad \text{// u 小于 v}\)
  • \(\textbf{else}\)
  • \(\quad \textbf{for} \quad i = n \to 0\)
  • \(\quad \quad \textbf{if} \quad u_i > v_i\)
  • \(\quad \quad \quad \textbf{return 1} \qquad \text{// u 大于 v}\)
  • \(\quad \quad \textbf{if} \quad u_i < v_i\)
  • \(\quad \quad \quad \textbf{return -1} \qquad \text{// u 小于 v}\)
  • \(\quad \textbf{return 0} \qquad \text{// u 等于 v}\)
/**
 * @brief bint_cmp 大整数比大小
 * @return 1(u>v), -1(u<v), 0(u=v)
 */
int bint_cmp(BINT u, BINT v) {
    if (u.dsize > v.dsize) {
        // u 长度大于 v
        return 1;
    } else if (u.dsize < v.dsize) {
        // u 长度小于 v
        return -1;
    } else {
        // u 和 v 长度相等
        for (int i = u.dsize; i >= 0; i--) {
            if (u.value[i] > v.value[i]) {
                return 1;
            } else if (u.value[i] < v.value[i]) {
                return -1;
            }
        }
    }
    return 0;
}

2. 大整数加法

对于两个 \(b\) 进制大整数 \(u=(u_{n}u_{n-1}\cdots u_0)_b\)\(v=(v_{n} v_{n-1} \cdots v_0)_b\),它们加法结果为 \(r=(r_{n+1} r_n r_{n-1} \cdots r_0)_b\)\(r=u+v\) 算法如下

  • \(carry \gets 0 \qquad \text{// 初始化进位}\)
  • \(\textbf{for} \quad i =0 \to n \qquad \text{// 从最低位遍历至最高位}\)
  • \(\quad r_i \gets (u_i + v_i + carry) \mod b \qquad \text{// 取结果}\)
  • \(\quad carry \gets \lfloor (u_i + v_i + carry) / b \rfloor \qquad \text{// 取进位}\)
  • \(r_{n+1} \gets carry\)

上述算法即普通的竖式加法

索引 n+1 n n-1 ... 0
0 \(u_n\) \(u_{n-1}\) ... \(u_0\)
+ 0 \(v_n\) \(v_{n-1}\) ... \(v_0\)
\(r_{n+1}\) \(r_n\) \(r_{n-1}\) ... \(r_0\)
BINT bint_add(BINT u, BINT v) {
    BINT r = BINT_NEW();
    int32_t maxlen = max(u.dsize, v.dsize);
    uint32_t carry = 0;  // 进位
    uint64_t tmp = 0;    //临时值
    for (int32_t i = 0; i < maxlen; i++) {
        tmp = (uint64_t)(u.value[i]) + (uint64_t)(v.value[i]) + carry;
        r.value[i] = (uint32_t)(tmp & UINT32_MAX);
        carry = (uint32_t)(tmp >> 32);
    }
    r.dsize = maxlen;
    // 处理结果长度
    if (carry > 0) {
        r.value[maxlen] = carry;
        r.dsize++;
    }
    return r;
}

3. 大整数减法

对于两个 \(b\) 进制大整数 \(u=(u_{n}u_{n-1}\cdots u_0)_b\)\(v=(v_{n} v_{n-1} \cdots v_0)_b\),它们减法结果为 \(r=(r_{n} r_{n-1} \cdots r_0)_b\)\(r=u-v \quad (u>v)\) 算法如下

  • \(borrow \gets 0 \qquad \text{//初始化借位}\)
  • \(\textbf{for} \quad i=0 \to n \qquad \text{// 从最低位遍历至最高位}\)
  • \(\quad \textbf{if} \quad u_i - v_i - borrow < 0 \qquad \text{// 需要借位}\)
  • \(\quad \quad r_i \gets u_i - v_i -borrow + b\)
  • \(\quad \quad borrow \gets 1 \qquad \text{// 设置借位}\)
  • \(\quad \textbf{else}\)
  • \(\quad \quad r_i \gets u_i - v_i -borrow\)
  • \(\quad \quad borrow \gets 0\)

上述算法即普通的竖式减法

索引 n n-1 ... 0
\(u_n\) \(u_{n-1}\) ... \(u_0\)
- \(v_n\) \(v_{n-1}\) ... \(v_0\)
\(r_n\) \(r_{n-1}\) ... \(r_0\)
BINT bint_sub(BINT u, BINT v) {
    BINT r = BINT_NEW();
    int32_t maxlen = max(u.dsize, v.dsize);
    uint32_t borrow = 0;  // 借位
    uint64_t tmp = 0;     //临时值
    for (int32_t i = 0; i < maxlen; i++) {
        tmp = (uint64_t)(u.value[i]) - (uint64_t)(v.value[i]) - borrow;
        r.value[i] = (uint32_t)(tmp & UINT32_MAX);
        // 存在借位
        if (tmp >> 32 != 0) {
            borrow = 1;
        } else {
            borrow = 0;
        }
    }
    r.dsize = maxlen;
    // 处理结果长度
    while (r.dsize > 0 && r.value[r.dsize - 1] == 0) {
        r.dsize--;
    }
    return r;
}

4. 大整数乘法

对于两个 \(b\) 进制大整数 \(u=(u_{n}u_{n-1}\cdots u_0)_b\)\(v=(v_{m} v_{m-1} \cdots v_0)_b\),它们乘法结果为 \(r=(r_{n+m} r_{m+n-1} r_{m+n-2} \cdots r_0)_b\)\(r=u \times v\) 算法如下

  • \(r \gets 0\)
  • \(\textbf{for} \quad i =0 \to m \qquad \text{// 遍历v}\)
  • \(\quad carry \gets 0 \qquad \text{// 初始化进位}\)
  • \(\quad \textbf{for} \quad j =0 \to n \qquad \text{// 遍历 u}\)
  • \(\quad \quad tmp \gets r_{i+j} + u_j \times v_i + carry\)
  • \(\quad \quad r_{i+j} \gets tmp \mod b\)
  • \(\quad \quad carry \gets \lfloor tmp / b \rfloor\)
  • \(\quad r_{i+n} \gets carry\)

上述算法即普通的竖式乘法

索引 ... n ... m ... 0
\(u_n\) ... \(u_m\) ... \(u_0\)
\(\times\) \(v_{m}\) ... \(v_0\)
... \(r_n\) ... \(r_{m}\) ... \(r_0\)

伪代码中的算法可以表达成公式

\[(r_{n+m} \cdots r_0)_b = \sum_{i=0}^{m} (u_n \cdots u_0)_b \times v_i \times b^{i} \]

BINT bint_mul(BINT u, BINT v) {
    BINT r = BINT_NEW();
    uint64_t tmp;
    uint32_t carry;
    for (int32_t i = 0; i < v.dsize; i++) {
        carry = 0;
        uint32_t val = v.value[i];
        for (int32_t j = 0; j < u.dsize; j++) {
            tmp = (uint64_t)(r.value[i + j]) + (uint64_t)(u.value[j]) * val +
                  carry;
            r.value[i + j] = (uint32_t)(tmp & UINT32_MAX);
            carry = (uint32_t)(tmp >> 32);
        }
        r.value[i + u.dsize] = carry;
    }
    r.dsize = u.dsize + v.dsize;
    // 处理结果长度
    while (r.dsize > 0 && r.value[r.dsize - 1] == 0) {
        r.dsize--;
    }
    return r;
}

对于大数乘法,上述只是一个比较普通的算法,在 Java BigInteger 库中还采用了分治的算法,例如 Karatsuba 或者 Toom-Cook。而且对于平方运算,还有更为高效的算法。

5. 乘法与除法(大整数与普通整数运算)

为了方面后续代码的实现,需要用到如下两个函数

(1) 乘和加

给定 \(b\) 进制大整数 \((u_n u_{n-1} \cdots u_0)_b\)\(b\) 进制整数 \(x\)\(y\),计算 \((r_{n+1} r_n \cdots r_0)_b = u \times x + y\) 的算法如下

  • \(r \gets 0\)
  • \(carry \gets y \qquad \text{// 初始化进位}\)
  • \(\textbf{for} \quad i =0 \to n \qquad \text{// 遍历 u}\)
  • \(\quad tmp \gets u_i \times x + carry\)
  • \(\quad r_i \gets tmp \mod b\)
  • \(\quad carry \gets \lfloor tmp / b \rfloor\)
  • \(r_{n+1} \gets carry\)
/**
 * @return u * mul_val + add_val
 */
static BINT bint_muladd(BINT u, uint32_t mul_val, uint32_t add_val) {
    BINT r = BINT_NEW();
    uint64_t tmp;
    uint32_t carry = add_val;
    for (int32_t i = 0; i < u.dsize; i++) {
        tmp = (uint64_t)(u.value[i]) * mul_val + carry;
        r.value[i] = (uint32_t)(tmp & UINT32_MAX);
        carry = (uint32_t)(tmp >> 32);
    }
    r.dsize = u.dsize;
    if (carry > 0) {
        r.value[r.dsize] = carry;
        r.dsize++;
    }
    return r;
}

(2) 除法

给定 \(b\) 进制大整数 \((u_n u_{n-1} \cdots u_0)_b\)\(b\) 进制整数 \(d\),计算 \((q_n q_{n-1} \cdots q_0)_b = u \div d\),余数为 \(b\) 进制整数 \(r\) 的算法如下

  • \(r \gets 0 \qquad \text{// 初始化余数}\)
  • \(\textbf{for} \quad i = n \to 0 \qquad \text{// 遍历 u}\)
  • \(\quad tmp \gets (r \times b + u_i)\)
  • \(\quad q_i \gets tmp \mod d\)
  • \(\quad r \gets \lfloor tmp / d \rfloor\)

上述即普通的竖式除法,从最高位开始除,记录当前位的商和余数,直至最低位

/**
 * @return u / div, 余数为 r
 */
static BINT bint_div_u32(uint32_t* r, BINT u, uint32_t div) {
    BINT q = BINT_NEW();
    uint64_t tmp;      //临时值
    uint64_t rem = 0;  //余数
    // 从最高位一直除至最低位
    for (int32_t i = u.dsize - 1; i >= 0; i--) {
        tmp = (rem << 32) | (uint64_t)(u.value[i]);
        q.value[i] = (uint32_t)(tmp / div);
        rem = tmp % (uint64_t)div;
    }
    if (r != NULL) {
        *r = (uint32_t)rem;
    }
    q.dsize = u.dsize;
    while (q.dsize > 0 && q.value[q.dsize - 1] == 0) {
        q.dsize--;
    }
    return q;
}

6. 字符串转化

(1) 大整数转字符串

\(b\) 进制大整数 \(u = (u_n u_{n-1} \cdots u_0)_b\) 转化为 \(radix\) 进制的字符串 \(s\) 算法如下

  • \(i \gets 0\)
  • \(\textbf{while} \quad u \neq 0\)
  • \(\quad rem \gets u \mod radix\)
  • \(\quad u \gets \lfloor u \div radix \rfloor\)
  • \(\quad s_i \gets \text{chr}(rem) \qquad \text{// 存储余数对应的字符数值}\)
  • \(\quad i \gets i + 1\)
  • \(s \gets \text{逆序}(s)\)
void bint_tostr(BINT u, char* s, int radix) {
    // 对0特殊判断
    if (u.dsize == 0) {
        s[0] = '0';
        s[1] = '\0';
        return;
    }
    // 除
    uint32_t r = 0;
    char* cur = s;
    while (u.dsize > 0) {
        u = bint_div_u32(&r, u, radix);
        if (r >= 0 && r <= 9) {
            // 0~9 直接转化为数字
            *cur = '0' + (char)r;
        } else {
            // 大于9的需要转为为字母a~z
            *cur = 'a' + (char)r - 10;
        }
        cur++;
    }
    *cur = '\0';
    // 逆序
    int32_t sLen = (int32_t)((size_t)cur - (size_t)s);
    for (int32_t i = 0; i < sLen / 2; i++) {
        char t = s[i];
        s[i] = s[sLen - 1 - i];
        s[sLen - 1 - i] = t;
    }
}

(2) 字符串转大整数

\(radix\) 进制字符串 \(s\) 转化为 \(b\) 进制大整数 \(u = (u_n u_{n-1} \cdots u_0)_b\) 算法如下

  • \(u \gets 0\)
  • \(\textbf{for} \quad i = 0 \to s.len - 1\)
  • \(\quad u \gets u \times radix + \text{int}(s_i)\)

为了有效地利用运算模块,可以将 \(k\) 个字符合并为一组(这也是Java BigInteger库所采用的策略),即

  • \(u \gets 0\)
  • \(i \gets 0\)
  • \(\textbf{while} \quad i < s.len\)
  • \(\quad u \gets u \times radix + \text{int}(s_i \cdots s_{i+k-1})\)
  • \(\quad i \gets i+k\)
BINT bint_fromstr(char* s, int radix) {
    // 对于 radix进制,以 parse_span[radix] 个字符为单位进行转化
    static int8_t parse_span[] = {0, 0, 30, 19, 15, 13, 11, 11, 10, 9, 9, 8, 8,
                                  8, 8, 7,  7,  7,  7,  7,  7,  7,  6, 6, 6, 6,
                                  6, 6, 6,  6,  6,  6,  6,  6,  6,  6, 5};

    uint32_t add_val;
    uint32_t mul_val;
    BINT r = BINT_NEW();
    // 进制转换
    while (*s != '\0') {
        // 初始化变量
        add_val = 0;
        mul_val = 1;
        // 进制转换,parse_span[radix] 个字符为一组
        for (int8_t i = 0; i < parse_span[radix]; i++) {
            uint32_t tmp = 0;
            char c = *s;
            // End of String
            if (c == '\0') {
                break;
            }
            // 字符转整数
            if ('0' <= c && c <= '9') {
                // 0~9
                tmp = c - '0';
            } else if ('A' <= c && c <= 'A' + radix - 10 - 1) {
                // A~Z
                tmp = c - 'A' + 10;
            } else if ('a' <= c && c <= 'a' + radix - 10 - 1) {
                // a~z
                tmp = c - 'a' + 10;
            }
            mul_val = mul_val * radix;        // 更新乘数
            add_val = add_val * radix + tmp;  // 更新加数
            s++;
        }
        r = bint_muladd(r, mul_val, add_val);
    }
    return r;
}

7. 除法

除法算法基于《计算机程序设计艺术》书中提到的一些定理,具体证明建议去阅读书籍

(1) 定理1

对于两个 \(b\) 进制大整数 \(u=(u_{n}u_{n-1}\cdots u_0)_b\)\(v=(v_{n-1} v_{n-2} \cdots v_0)_b\),它们除法结果为 \(q\)

通过竖式除法可知,若 \((u_{n}u_{n-1}\cdots u_1)_b < (v_{n-1} v_{n-2} \cdots v_0)_b\),那么 \(q\) 的长度为 1

\[\hat{q} = \text{min}(\lfloor \frac{u_n b + u_{n-1}}{v_{n-1}} \rfloor, b - 1) \]

如果有 \(v_{n-1} \geq \lfloor b/2 \rfloor\),那么 \(\hat{q} -2 \leq q \leq \hat{q}\)

(2) 定理2

对于定理1中的 \(u,v,\hat{q}\),令 \(\hat{r} = (u_n b + u_{n-1}) \mod v_{n-1}\)

测试 \(\hat{q} = b\)\(\hat{q} v_{n-2} > b \hat{r} + u_{n-2}\),如果是,则 \(\hat{q}\) 减1,\(\hat{r}\) 加上 \(v_{n-1}\),如果 \(\hat{r} < b\) 则重复此测试

通过上述测试可以高速确定 \(\hat{q}\)\(q\) 大 1 的大多数情况,且消除 \(\hat{q}\)\(q\) 大 2 的所有情况

(3) 除法算法

对于两个 \(b\) 进制大整数 \(u=(u_{n}u_{n-1}\cdots u_0)_b\)\(v=(v_{m} v_{m-1} \cdots v_0)_b\),它们除法结果为 \(q=(q_{n-m} q_{n-m-1} \cdots q_0)_b\),余数为 \(r=(r_m r_{m-1} \cdots r_0)_b\)

在《计算机程序设计艺术》中,除法算法分为 D1 ~ D8 共 8 步,下面的伪代码按照书中给的算法顺序

  • \(\text{// === D1 规格化 ===}\)
  • \(d \gets \lfloor b / (v_m + 1) \rfloor \qquad \text{//为了让}v_m \geq \lfloor b/2 \rfloor\)
  • \(u \gets u \times d\)
  • \(v \gets v \times d\)
  • \(u_{u.len} \gets 0 \qquad \text{// 将 u 的最高位的下一位置0}\)
  • \(\text{// === D2 初始化 j ===}\)
  • \(\textbf{for} \quad j=u.len - v.len \to 0\)
  • \(\quad \text{// === D3 计算 q ===}\)
  • \(\quad \hat{q} \gets \lfloor(u_{j+m} b + u_{j+m-1}) / v_{m} \rfloor\)
  • \(\quad \textbf{if} \quad \hat{q} = 0\)
  • \(\quad \quad q_j \gets 0\)
  • \(\quad \quad \textbf{continue} \qquad \text{// 跳过本轮}\)
  • \(\quad \textbf{if} \quad \hat{q} \ge b\)
  • \(\quad \quad \hat{q} \gets b - 1\)
  • \(\quad \hat{r} \gets (u_{j+m} b + u_{j+m-1}) - v_m \times \hat{q}\)
  • \(\quad \text{// 测试}\hat{q}\)
  • \(\quad \textbf{while} \quad \hat{q} \times v_{m-1} > \hat{r} \times b + u_{j+m-2}\)
  • \(\quad \quad \hat{q} \gets \hat{q} - 1 \qquad \text{// 更新}\hat{q}\)
  • \(\quad \quad \hat{r} \gets \hat{r} + v_m \qquad \text{// 更新}\hat{r}\)
  • \(\quad \quad \quad \textbf{if} \quad \hat{r} \ge b\)
  • \(\quad \quad \quad \quad \textbf{break} \qquad \text{// 跳出while循环}\)
  • \(\quad \text{// === D4 乘和减 ===}\)
  • \(\quad u \gets u - q \times (v \times b^{j})\)
  • \(\quad \text{// === D5 测试余数 ===}\)
  • \(\quad \textbf{if} \quad u < 0\)
  • \(\quad \quad \text{// === D6 往回加 ===}\)
  • \(\quad \quad u \gets u + (v \times b^{j})\)
  • \(\quad \quad \hat{q} \gets \hat{q} - 1\)
  • \(\quad q_j \gets \hat{q}\)
  • \(\quad \text{// === D7 对 j 进行循环 ===}\)
  • \(\text{// === D8 逆规格化 ===}\)
  • \(r \gets u \div d\)

(4) 除法代码

代码对算法中的 \(d \gets \lfloor b / (v_m + 1) \rfloor\) 进行了修改,通过计算 \(v_m\) 的最高有效位来决定 \(d\) 的值,对于 \(u \gets u \times d\) 操作可采用大整数的移位运算实现,这在二进制计算机中具备更高效的效率(为了便于理解,下方代码中并未使用移位,而是采用了较为低效的乘法操作)

static int significant_bit(uint32_t n) {
    int bit = 0;
    // 判断最高有效位
    while (n != 0) {
        n >>= 1;
        bit++;
    }
    return 32 - bit;
}

BINT_DIV_T bint_div(BINT u, BINT v) {
    BINT_DIV_T res = {BINT_NEW(), BINT_NEW()};
    // 除数为0
    if (v.dsize == 0) {
        // 直接返回
        return res;
    }
    // 除数长度为 1(选择更高效的算法)
    if (v.dsize == 1) {
        uint32_t divisor = v.value[0];  //除数
        uint32_t rem = 0;               //余数
        res.q = bint_div_u32(&rem, u, divisor);
        res.r.value[0] = rem;
        if (rem > 0) {
            res.r.dsize++;
        }
        return res;
    }
    // Knuth 除法

    // D1 规格化
    int d = significant_bit(v.value[v.dsize - 1]);
    u = bint_muladd(u, 1 << d, 0);
    v = bint_muladd(v, 1 << d, 0);
    u.value[u.dsize] = 0;  // 为了方便 D2 的循环迭代
    // D2 初始化 j
    int32_t div_len = v.dsize;
    uint32_t div_h = v.value[div_len - 1];
    uint32_t div_l = v.value[div_len - 2];
    uint64_t base = (uint64_t)UINT32_MAX + 1;  // 2^32
    for (int32_t j = u.dsize - div_len; j >= 0; j--) {
        // D3 计算qhat
        uint64_t qhat, rhat;
        uint64_t uh = (uint64_t)(u.value[j + div_len]);
        uint64_t ul = (uint64_t)(u.value[j + div_len - 1]);
        uint64_t ul2 = (uint64_t)(u.value[j + div_len - 2]);

        qhat = (uh * base + ul) / (uint64_t)div_h;
        if (qhat > UINT32_MAX) {
            // 防止计算出的qhat过大
            qhat = UINT32_MAX;
        }
        rhat = (uh * base + ul) - (uint64_t)div_h * qhat;
        if (qhat == 0) {
            // 商为0,跳过本轮
            res.q.value[j] = 0;
            continue;
        }
        while (qhat * (uint64_t)div_l > base * rhat + ul2) {
            // 调整qhat
            qhat--;
            rhat += div_h;
            if (rhat >= base) {
                break;
            }
        }
        // D4 乘和减
        uint64_t tmp = 0;     //临时值
        uint64_t borrow = 0;  // u的借位
        uint64_t carry = 0;   // div的进位
        for (int32_t i = 0; i < div_len; i++) {
            // 计算乘法 div*qhat
            uint64_t t = qhat * (uint64_t)(v.value[i]) + carry;
            carry = t >> 32;
            // 计算减法 u - div*qhat
            tmp = (uint64_t)(u.value[j + i]) - (t & UINT32_MAX) - borrow;
            borrow = (tmp >> 32) ? 1 : 0;
            // 赋值
            u.value[j + i] = (uint32_t)(tmp & UINT32_MAX);
        }
        if (borrow != 0 || carry != 0) {
            tmp = (uint64_t)(u.value[j + div_len]) - carry - borrow;
            borrow = (tmp >> 32) ? 1 : 0;
            u.value[j + div_len] = (uint32_t)(tmp & UINT32_MAX);
        }
        // D5 测试余数
        res.q.value[j] = (borrow == 0) ? (uint32_t)qhat : (uint32_t)qhat - 1;
        // D6 往回加
        if (borrow != 0) {
            uint64_t carry = 0;  // 加法进位
            for (int32_t i = 0; i < div_len; i++) {
                tmp = (uint64_t)u.value[j + i] + (uint64_t)(v.value[i]) + carry;
                u.value[j + i] = (uint32_t)(tmp & UINT32_MAX);
                carry = tmp >> 32;
            }
            if (carry != 0) {
                tmp = (uint64_t)u.value[j + div_len] + carry;
                u.value[j + div_len] = (uint32_t)(tmp & UINT32_MAX);
                // 将之后的进位丢弃,以抵消D4中的借位
            }
        }

    }  // D7 对 j 进行循环

    res.q.dsize = u.dsize - div_len + 1;
    while (res.q.dsize > 0 && res.q.value[res.q.dsize - 1] == 0) {
        res.q.dsize--;
    }
    res.r = bint_div_u32(NULL, u, 1 << d);

    return res;
}

8. 代码运行结果

上述代码的调用方式大致如下

#include <stdio.h>
#include <stdlib.h>
#include "bint.h"

int main() {
    BINT a, b, c;
    BINT_DIV_T d;
    char sbuffer[1024], sbuffer2[1024];
    // load from str
    a = bint_fromstr("1234567123456712345671234567", 10);
    b = bint_fromstr("654321654321654321654321", 10);

    c = bint_add(a, b);  // a + b
    bint_tostr(c, sbuffer, 10);
    // a+b=1235221445111033999992888888
    printf("a+b=%s\n", sbuffer);

    c = bint_sub(a, b);  // a - b
    bint_tostr(c, sbuffer, 10);
    // a-b=1233912801802390691349580246
    printf("a-b=%s\n", sbuffer);

    c = bint_mul(a, b);  // a * b
    bint_tostr(c, sbuffer, 10);
    // a*b=807804002591322070054017119327931540612061880114007
    printf("a*b=%s\n", sbuffer);

    d = bint_div(a, b);  // a / b
    bint_tostr(d.q, sbuffer, 10);
    bint_tostr(d.r, sbuffer2, 10);
    // a/b=1886...516483406072295031185161
    printf("a/b=%s...%s\n", sbuffer, sbuffer2);

    bint_tostr(a, sbuffer, 10);  // to str
    bint_tostr(a, sbuffer2, 16);
    // a=1234567123456712345671234567(10)
    // a=3fd35c1ddd60c78fbb0f407(16)
    printf("a=%s(10)\na=%s(16)\n", sbuffer, sbuffer2);

    // load from hex str
    a = bint_fromstr("3fd35c1ddd60c78fbb0f407", 16);
    bint_tostr(a, sbuffer, 10);
    // a=1234567123456712345671234567(10)
    printf("a=%s(10)\n", sbuffer);

    return 0;
}

三、 其它

上述给出的代码其实还有较大的优化空间

  • 为了便于理解并没有使用动态内存分配,故当数据长度超过数组范围时就会发生溢出
  • 存在更为高效的乘法算法(分治),对于平方还能进一步优化
  • 没有设计左右移位的算法,在乘或除2的幂次方数据时,采用移位高效得多得多
  • 没有进行非法输入检测
  • 函数传参可以使用指针,而不是直接传递结构体。在Openssl中,为了降低运算时临时数据内存分配的开销,额外设置了一个用于存储上下文CTX的结构

参考:《计算机程序设计艺术》,Donald E. Knuth

posted @ 2022-04-22 22:37  kentle  阅读(1087)  评论(0编辑  收藏  举报