C++实现128位整数类

如何编写一个128位的整数

现在的大部分的计算机编程语言都包含了64位的有符号整数和无符号整数,有的甚至还提供了128位的整数和大数,比如:

  • \(C\#\) : System.Int128, System.UInt128
  • \(Rust\): i128, u128

但是在C/C++中并未发现uint128_t/int128_t,尽管在某些平台下可以看到__int128_t/__uint128_t。所以笔者在打算在本文中简单实现以上两种类型。

如何实现一个128位的整数

设计

我们假设当前的计算机是支持64位的,那么128位的整数我们可以看成是两个64位的整数拼接而成。

template <bool Signed, std::endian Endian>
struct int128_layout
{
    using lower_type = uint64_t;
    using upper_type = std::conditional_t<Signed, int64_t, uint64_t>;

    static constexpr bool lower_index = (Endian == std::endian::little);

    uint64_t m_data[2];
    
    constexpr int128_layout() = default;

    constexpr int128_layout(upper_type upper, lower_type lower) : m_data{ static_cast<uint64_t>(upper), lower } { }

    upper_type upper() const { return static_cast<upper_type>(m_data[1 - lower_index]); }

    lower_type lower() const { return static_cast<lower_type>(m_data[lower_index]); }
};

如果是小端存储,那么lower应该存储在低地址,而upper应该存储在高地址。

加法

对于无符号整数的加法,我们直接将高位和低位对应相加即可,如果低位产生了进位,我们再把高位加一即可。在这里我们不考虑高位进位的情况,即它会和其他无符号整数一样warp around。

// 对于128位的整数,我们尽可能的使用值而非引用(const&),具体原因可以参考std::string_view。
// https://quuxplusone.github.io/blog/2021/11/09/pass-string-view-by-value/
constexpr uint128 operator+(this uint128 lhs, uint128 rhs) 
{
    const auto lo = lhs.lower() + rhs.lower();
    const auto carry = (lo < lhs.lower() ? 1 : 0);
    const auto hi = lhs.upper() + rhs.upper() + carry;
    return uint128(hi, lo);
}

对于有符号整数的加法,其操作和无符号整数完全一致,我们只需要将其转化为无符号整数进行相加即可。

减法

减法的原理和加法是一样的,这里不再赘述。直接给出代码:

constexpr uint128 operator-(this uint128 lhs, uint128 rhs) 
{
    const auto lo = lhs.lower() - rhs.lower();
    const auto carry = (lo <= lhs.lower() ? 0 : 1);
    const auto hi = lhs.upper() - rhs.upper() - carry;
    return uint128(hi, lo);
}

后续如无特殊说明,则有符号整数的操作和无符号整数的操作完全一致。

乘法

乘法我们就直接按照数学上的方式来即可,当然这种做法会比较慢。我们依然以8位的整数来举例:

\[ X = 11111100_{0b} \\ Y = 00000011_{0b} \\ \]

显然,8位的数字相乘,结果应该是一个16位的数字。

\[ X * Y = 00000010'11110100_{0b} \]

\[ Res_{upper} = 00000010_{0b} \\ Res_{lower} = 11110100_{0b} \]

在C++中,乘法结果的类型应该和操作数是一致的,以上面的例子来看,结果也必然是一个8位的整数,那么超过8位的部分就会被舍弃。

我们将8位整数分成三个部分:

\[ X = 11111100_{0b} \\ 高四位: \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ H4(X) = 1111 \\ 低四位: \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ L4(X) = 1100 \\ 低四位的高两位:\ LH(X) = 11 \\ 低四位的低两位:\ LL(X) = 00 \]

同理

\[ Y = 00000011_{0b} \\ H4(Y) = 0000_{0b} \\ L4(Y) = 0011_{0b} \\ LH(Y) = 00_{0b} \\ LL(Y) = 11_{0b} \]

\(X\)\(Y\)各分成三部分之后,\(X*Y\)就可以拆成各部分乘积之和

\[ X * Y = \\ H4(X) * H4(Y) + H4(X) * L4(Y) + \\ L4(X) * H4(Y) + L4(X) * L4(Y) \ \ \ \ \ \]

显然,
\( H4(X) * H4(Y) \)会溢出,所以我们并不需要考虑其结果。

\[ H4(X) * L4(Y) + L4(X) * H4(Y) \]

结果在高位。对于

\[ L4(X) * L4(Y) \]

我们将拆分为

\[ LH(X) * LH(Y) + \\ LH(X) * LL(Y) + \\ LL(X) * LH(Y) + \\ LL(X) * LL(Y) \ \ \ \ \]

同理可知,\(LH(X)*LH(Y)\) 的结果在高4位,\(LH(X) * LL(Y)+LH(Y) * LL(X)\)的结果在高2位和低2位,\(LL(X)*LL(Y)\)的结果在低四位。

将所有的结果加在一起就可以得到最终的结果。

constexpr uint128 operator*(this uint128 lhs, uint128 rhs) 
{
    // We split uint128 into three parts: 
    // High64bit(H64), Low-High32bit(LH32) and Low-Low32bit(LL32)
    // |----------------|--------|--------|
    // 128             64        32       0
    //        H64           LH32     LL32
    //                           L64
    // A * B = 
    // H64(A) * H64(B) => Overflow
    // H64(A) * L64(B) => H64
    // L64(A) * H64(B) => H64
    // L64(A) * L64(B) =    
    //      LH32(A) * LH32(B) => H64
    //      LL32(A) * LH32(B) => H64 or L64
    //      LH32(A) * LL32(B) => H64 or L64
    //      LL32(A) * LL32(B) => L64

    constexpr uint64_t mask = 0xffffffff;   // mask low 64-bit

    const uint64_t ah = lhs.lower() >> 32;
    const uint64_t al = lhs.lower() & mask;

    const uint64_t bh = rhs.lower() >> 32;
    const uint64_t bl = rhs.lower() & mask;
    
    const auto part_hi = lhs.upper() * rhs.lower() // H64(A) * L64(B)
                        + lhs.lower() * rhs.upper() // L64(A) * H64(B)
                        + ah * bh;              // LH32(A) * LH32(B)

    const auto part_lo = al * bl; // LL32(A) * LL32(B)

    uint128 result(part_hi, part_lo);

    result += uint128(ah * bl) << 32;  // LH32(A) * LL32(B)
    result += uint128(bh * al) << 32;  // LL32(A) * LH32(B)

    return result;
}

除法

我们使用最传统的移位减法实现除法,这个方法很容易理解但是速度比较慢。对于十进制的除法,我们需要去估计下一位数字,但是对于二进制来说,如果当前的被除数大于等于平移后的除数,下一位数字就必然为1,否则为0。

static constexpr div_result<uint128> div_mod(uint128 dividend, uint128 divisor)
{
    assert(divisor != 0 && "dividend = quotient * divisor + remainder");

    // https://stackoverflow.com/questions/5386377/division-without-using
    if (divisor > dividend)
    {
        return { uint128(0), dividend };
    }

    if (divisor == dividend)
    {
        return { uint128(1), uint128(0) };
    }
    
    uint128 denominator = divisor;
    uint128 current = 1;
    uint128 answer = 0;

    // Follow may be faster.
    // const int shift = denominator.countl_zero() - dividend.countl_zero() + 1; 
    const int shift = countl_zero(denominator) - countl_zero(dividend) + 1; 
    denominator <<= shift;
    current <<= shift;

    // After loop, the current will be zero.
    for (int i = 0; i <= shift; ++i)
    {
        if (dividend >= denominator)
        {
            dividend -= denominator;
            answer |= current;
        }
        current >>= 1;
        denominator >>= 1;
    }
    
    return { answer, dividend };
}

移位

左移

有符号整数和无符号整数的左移操作时完全一致的,都是在最低为补0。假设我们需要左移N位,以16位整数为例子,我们将高8位和低8位分别表示为:

\[ Lower = 11111100_{0b} \\ Upper = 00111111_{0b} \]

\(N\)不足8位时:低位部分左移之后补0,高位部分左移之后补低位溢出的部分。

\[ N_1 = 2 \\ Lower_{new1} = 11110000_{0b} \\ Upper_{new1} = 11111111_{0b} \\ \\ \]

\(N\)超过8位时:低位部分全部变成0,高位部分在低位的基础上继续左移\(N-8\)位。

\[ N_2 = 10 \\ Lower_{new2} = 00000000_{0b} \\ Upper_{new2} = 11110000_{0b} \]

constexpr uint128 operator<<(this uint128 lhs, uint64_t amount) 
{
    assert(amount < 128 && "");

    const auto hi = lhs.upper();
    const auto lo = lhs.lower();
    
    if (amount >= 64)
    {
        return uint128(lo << (amount - 64), 0);
    }
    else if (amount > 0)
    {
        return uint128((hi << amount) | (lo >> (64 - amount)), lo << amount);
    }
    else
    {
        return lhs;
    }
}


右移

整数右移高位补符号位,对于无符号整数来说自然是补0,对于有符号整数来说,当该数为负数时补1,当该数为非负数的时候补0。假设我们需要右移\(N\)位,以16位整数为例,我们将高8位和低8位分别表示为:

\[ Lower = 11111100_{0b} \\ Upper = s0111111_{0b} \]

其中\(s\)表示最高位,对于有符号整数来说\(s\)是符号位,对于无符号整数来说没有什么特别的意义。

\(N\)不足8位时:低位部分左移之后补0,高位部分左移之后补低位溢出的部分。

\[ N_1 = 2 \\ Lower_{new1} = 11111111_{0b} \\ Upper_{new1} = sss01111_{0b} \\ \]

\(N\)超过8位时:低位部分全部变成\(s\),低位部分在高位的基础上继续左移\(N-8\)位。

\[ N_2 = 10 \\ Lower_{new2} = sss01111_{0b} \\ Upper_{new2} = ssssssss_{0b} \]


constexpr uint128 operator>>(this uint128 lhs, uint64_t amount) 
{
    assert(amount < 128 && "");

    const auto hi = lhs.upper();
    const auto lo = lhs.lower();

    if (amount >= 64)
    {
        return uint128(0, hi >> (amount - 64));
    }
    else if (amount > 0)
    {
        return uint128(hi >> amount, (lo >> amount) | (hi << (64 - amount)));
    }
    else
    {
        return lhs;
    }
}

constexpr int128 operator>>(this int128 lhs, uint64_t amount) 
{
    // We use uint64_t instead of int to make amount non-negative.
    // The result is undefined if the right operand is negative, or 
    // greater than or equal to the number of bits in the left expression's type.
    assert(amount < 128 && "");

    const auto result = static_cast<uint128<Endian>>(lhs) >> amount;
    // Right-shift on signed integral types is an arithmetic right shift, 
    // which performs sign-extension. So we must keep sign bit when shifting 
    // signed integer.
    if (signbit(lhs.upper()))
    {
        return result | (uint128<Endian>::max() << (127 - amount));
    }
    return result;
}

逻辑右移

有些编程语言中会自带逻辑右移运算符,比如\(C\#\)\(Java\)中的\(>>>\),无论对于有符号整数还是无符号整数,\(>>>\) 始终都是在高位补0。

显然对于无符号整数来说,\(>>>\)\(>>\)操作符是完全等价的,对于有符号的整数来说,我们只需要将其转化为无符号的整数进行\(>>\)操作最后再转换为有符号即可。

signed_integral operator>>>(signed_integral si, int amount) {
    using unsigned_integral = make_signed_t<signed_integral>;
    unsigned_integral ui = (unsigned_integral)si;
    si >>= amount;
    return signed_integral(si);
}

最值

对于无符号整数,当所有的bit都为1的时候取最大,当所有bit都为0的时候取最小。

static consteval uint128 max() 
{ 
    // 11111...11111
    return uint128(
        std::numeric_limits<uint64_t>::max(), 
        std::numeric_limits<uint64_t>::max()
    );
}

static consteval uint128 min() { return uint128(0, 0); }

对于有符号整数,当该数为负数时,非符号位的bit全为0取最小,反之最大。

static consteval int128 max() 
{ 
    // 0111111...111
    return int128(
        std::numeric_limits<int64_t>::max(),
        std::numeric_limits<uint64_t>::max()
    );
}

static consteval int128 min()
{
    // 1000000...000
    return int128(
        std::numeric_limits<int64_t>::min(),
        0
    );
}
posted @ 2024-05-22 22:28  鸿钧三清  Views(307)  Comments(0Edit  收藏  举报