RSA

RSA,Rivest-Shamir-Adleman 算法,是一个常见的非对称加密算法。本文将简明扼要通俗易懂地介绍 RSA 的原理,并给出 Python 实现。

本文同步发表于我的博客 https://clouder0.com/zh-cn/posts/rsa/

Why we need RSA?

加密的需求大家都很熟悉,但非对称加密呢?我们为什么需要非对称加密?

想象以下的场景:

  • 数字签名: 你希望以你的名义发送一封邮件(或者发布任何的数据内容),并且向公众公开。

    • 但是,你不希望有其他人能够以你的名义发布内容。
    • 这个时候,公众可以解密,你可以加密。你持有私钥,公众持有公钥。私钥签名,公钥验签。
  • 加密通讯: 你建了一个小网站,你希望用户的数据在传输到你的网站的过程中是安全的。也就是说,用户发送加密后的数据,你解密数据。

    • 这个时候,用户持有公钥进行加密,你持有私钥进行解密。
    • 任何人都可以向你发送加密内容,但只有持有私钥的你才能解密。

这是非对称加密的两大经典应用场景。

如果只才能对称加密的话,这两种场景都是无法实现的:数字签名自然不必说,加密通讯的话,由于你需要让公众具有加密的能力,但又不希望他们能够解密,自然也需要非对称。

How is this possible?为什么还能做出这种「有两种密码,一个加密一个解密」的神奇算法呢?

这一般都是利用了非对称性。例如:将两个质数乘起来得到结果是简单的,但想要对某个大数做质因数分解,复杂度则极其高。

How RSA works...

RSA 利用的就是质因数分解复杂度的非对称性。

我们首先选择两个足够大的质数,记为 \(p,q\),然后:

  • 计算 \(n=pq\),这将用作模数。\(n\) 的长度被称为 key length,现在比较流行的是 512/2048/4096bits.

  • 计算 \(\lambda(n)\),即 \(n\) 的 Carmichael's totient function. 也就是找到最小的正整数 \(m\) 使得 \(a^m \equiv 1 \pmod n\) 对任意 \(a\) 都成立。

    • 在原教旨 RSA 中,选取的 \(\lambda(n)\) 实际上是 Euler's totient function,不是最小的正数 \(m\),会比较大,但也是能用。
    • 现在一般选用 \(\lambda(n) = \operatorname{lcm}(p-1,q-1)\). 证明等会再说。
  • 选择一个整数 \(e\),满足 \(e \in (1,\lambda(n))\)\(\gcd(e, \lambda(n)) = 1\).

    • \(e\) 最好具有较短的 bit-length 和较小的 hamming weight,大家经常选用的是 \(e=2^{16}+1= 65537\).
  • 计算 \(d \equiv e^{-1} \pmod {\lambda(n)}\),也就是 \(e\) 的乘法逆元。

好的,密钥生成部分结束了。

接下来,我们将 \((n, e)\) 作为加密密钥,\((n,d)\) 作为解密密钥。而剩下的 \(p,q,\lambda(n)\) 应当保密或者直接扔掉。


然后就是加密了,加密相当的简单啊,加入我们想要传递原文 \(M\),首先使用 padding 将其变成 \(m\),满足 \(0 \le m < n\). 这里的 padding 只要是一种可逆的变换就行了。

然后计算:\(c \equiv m^e \pmod n\),这里的 \(c\) 就是我们的加密结果了。

使用快速幂,可以在较短的时间内完成计算。


解密也相当的简单,我们持有密文 \(c\),想要获得 padded 后的原文 \(m\),那么:

\[c^d \equiv (m^e)^d \equiv m \pmod n \]

这里利用的核心原理是:\(ed \equiv 1 \pmod n\),实际上这就是 \(d\) 的定义式。

相信大家已经完全理解 RSA 了,笑。

Math behind the scene

让我们思考一下,RSA 算法的执行流程已经讲完了,但它为什么能保证安全性、为什么能保证正确性呢?

RSA 的核心原理是:\(e\)\(d\) 只有一个公开。而 \(m^{ed} \equiv m^{\lambda(n)} \equiv m \pmod n\).
这里 \(ed \equiv \lambda(n) \pmod n\) 就是解密密钥 \(d\) 的定义式。而 \(m^{\lambda(n)} \equiv m \pmod n\) 就是 \(\lambda(n)\) 的定义式。

实际上,\(e\)\(d\) 是相当对称的。假如持有 \(e\) 进行加密,加密后 \(c=m^e\),则 \(c^d \equiv m\). 用 \(d\) 加密也是一样的:\(c=m^d, c^e = m\).

也就是——实际上持有公钥的用户也可以既加密、又解密……吗?比如我们原本约定好公钥加密,私钥解密,that's fine. 但哪天你抽风了说我们换换位置,公钥解密私钥加密,那也是无缝切换。

当然,工程实践上公钥经常取固定的 \(e=65537\),嘛。


接下来还有一个问题,\(\lambda(n) = \operatorname{lcm}(p-1,q-1)\),为什么就有 \(m^{\lambda(n)} \equiv m \pmod n\)

根据众所周知的费马小定理,我们知道:当 \(p\) 为素数时,\(a^{p-1} \equiv 1 \pmod p\).

而当 \(n=pq\) 时,显然 \(n\) 就不是素数了,我们要找到 \(a^{\lambda(n)} \equiv 1 \pmod n\),这个时候可以使用欧拉定理:

\[a^b \equiv \begin{cases} a^{b \bmod \varphi(p)},b < \varphi(p) \\ a^{b \bmod \varphi(p) + \varphi(p)},b \geq \varphi(p) \end{cases} \pmod{p} \]

其中 \(\varphi(p)\) 为欧拉函数。欧拉函数满足积性,也就是 \(\varphi(pq) = \varphi(p) \times \varphi(q)\). 并且有对于素数 \(p\)\(\varphi(p) = p-1\).

那么 \(\varphi(n) = \varphi(pq) = \varphi(p) \times \varphi(q) = (p-1)(q-1)\),非常 reasonable.

那么显然,我们就可以得到:

\[a^{\varphi(n)} \equiv a^0 \equiv 1 \pmod n \]

这就算是求出了一个满足需要的 \(\lambda(n)\)...了吗?注意到我们的定义是最小的 \(m\) 使得 \(a^m \equiv 1 \pmod n\),这里的 \(\varphi(n)\) 未必是最小的。

当然,实际上是不是最小的其实对 RSA 影响不大。

接下来就是 Carmichael function,其计算如下:

\[\lambda(n) = \begin{cases} \varphi(n), &\text{if } n \text{ is }1,2,3,4 \text{ or an odd prime power} \\ \dfrac{1}{2}\varphi(n), &\text{ if } n = 2^r, r \ge 3 \\ \operatorname{lcm}\left( \lambda(n_1),\cdots,\lambda(n_k) \right), &\text{ if } n = n_1n_2\cdots n_k, \text{ where } n_i \text{ are}\\ &\text{ power of distinct prime numbers} \end{cases} \]

在这里,因为 \(n=pq\)\(p,q\) 都是质数,那么 \(\lambda(pq) = \operatorname{lcm}(\varphi(p),\varphi(q))= \operatorname{lcm}(p-1,q-1)\).

Implementation

涉及到大数运算,人生苦短,我用 Python.

但是 Python 确实不是很快,我决定使用稍微短一些的 pq. 1024bits 吧,这样最终的 \(n\) 就是 2048bits.

以下是核心代码:

import random


def miller_rabin(n: int, k: int):
    """use miller rabin method to test prime."""
  
    if n == 2:
        return True

    if n % 2 == 0:
        return False

    r, s = 0, n - 1
    while s % 2 == 0:
        r += 1
        s //= 2
    for _ in range(k):
        a = random.randrange(2, n - 1)
        x = pow(a, s, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False
    return True


def exgcd(a: int, b: int):
    """exgcd to cacl inverse."""
    if b == 0:
        return a, 1, 0
    d, x, y = exgcd(b, a % b)
    x, y = y, x - (a // b) * y
    return d, x, y


def is_prime(n: int) -> bool:
    return miller_rabin(n, 40)

def inv(a: int, m: int) -> int:
    """calc modular inverse."""
    d, x, y = exgcd(a, m)
    if d != 1:
        raise RuntimeError("modular inverse does not exist")
    return x % m


def rsa_encrypt(m: int, e: int, n: int) -> int:
    return pow(m, e, n)


def rsa_decrypt(c: int, d: int, n: int) -> int:
    return pow(c, d, n)


def gcd(a: int, b: int) -> int:
    if b == 0:
        return a
    return gcd(b, a % b)


def lcm(a: int, b: int) -> int:
    return a * b // gcd(a, b)


def rsa_gen(p: int, q: int) -> tuple[int, int, int]:
    n = p * q
    l = lcm(p - 1, q - 1)
    e = 65537
    d = inv(e, l)
    return n, e, d


def get_big_prime():
    while True:
        p = random.getrandbits(1024)
        if is_prime(p):
            return p


def get_pq() -> tuple[int, int]:
    return get_big_prime(), get_big_prime()

def main():
    n, e, d = rsa_gen(*get_pq())
    print("n =", n)
    print("e =", e)
    print("d =", d)
    origin = int(input("origin: "))
    c = rsa_encrypt(origin, e, n)
    print("c =", c)
    print("origin =", rsa_decrypt(c, d, n))

    assert origin == rsa_decrypt(c, d, n)

    print("OK")


if __name__ == "__main__":
    main()

一般而言,RSA 的速度较为缓慢,我们可以将 RSA 和对称加密配合使用,比如说用 RSA 传递对称加密的密钥,以实现加密通讯。

处理的长度过长的时候,需要分块。emmm,注意到计算在 \(\bmod n\) 下进行,需要分块后比 \(n\) 小。

def encrypt_file(n, e):
    with open("input.txt", "rb") as f:
        data = f.read()
    # chunking by 255 bytes
    chunks = [data[i : i + 255] for i in range(0, len(data), 255)]

    with open("output.txt", "wb") as f:
        for chunk in chunks:
            m = int.from_bytes(chunk, "little")
            c = rsa_encrypt(m, e, n).to_bytes(512, "little")
            f.write(c)


def decrypt_file(n, d):
    with open("output.txt", "rb") as f:
        data = f.read()

    chunks = [data[i : i + 512] for i in range(0, len(data), 512)]
  
    with open("output_de.txt", "wb") as f:
        for chunk in chunks[:-1]:
            c = int.from_bytes(chunk, "little")
            m = rsa_decrypt(c, d, n).to_bytes(255, "little")
            f.write(m)

        c = int.from_bytes(chunks[-1], "little")
        m = rsa_decrypt(c, d, n).to_bytes(255, "little")
        # trim trailing zeros
        while m[-1] == 0:
            m = m[:-1]
        f.write(m)
          
def test_file():
    n, e, d = rsa_gen(*get_pq())
    encrypt_file(n, e)
    decrypt_file(n, d)

posted @ 2024-10-28 19:08  *Clouder  阅读(4)  评论(0编辑  收藏  举报