模型加密杂谈:TensorRT加密

在大多数项目交付场景中,经常需要对部署模型进行加密。模型加密一方面可以防止泄密,一方面可以便于模型跟踪管理,防止混淆。

由于博主使用的部署模型多为TensorRT格式,这里以TensorRT模型为例,讲解如何对模型进行加密、解密以及推理加密模型。

代码仓库:https://github.com/laugh12321/model_crypto

加密算法的选择和支持的库

Crypto++ 是C/C++的加密算法库,基本上涵盖了市面上的各类加密解密算法,包括对称加密算法(AES等)和非对称加密算法(RSA等)。

两种算法使用的场景不同,非对称加密算法一般应用于数字签名和密钥协商的场景下,而对称加密算法一般应用于纯数据加密场景,性能更优。在对模型的加密过程中使用对称加密算法。

简易版本

以AES-GCM加密模式为例,编写一个检测的加密、解密方法

std::string Encrypt(const std::string& data, const CryptoPP::SecByteBlock& key, const CryptoPP::SecByteBlock& iv) {
    std::string cipher;

    try {
        CryptoPP::GCM<CryptoPP::AES>::Encryption e;
        e.SetKeyWithIV(key, key.size(), iv, iv.size());

        CryptoPP::StringSource(data, true, 
            new CryptoPP::AuthenticatedEncryptionFilter(e,
                new CryptoPP::StringSink(cipher)
            ) // StreamTransformationFilter
        ); // StringSource
    }
    catch(const CryptoPP::Exception& e) {
        std::cerr << e.what() << std::endl;
        exit(1);
    }
    return cipher;
}

std::string Decrypt(const std::string& cipher, const CryptoPP::SecByteBlock& key, const CryptoPP::SecByteBlock& iv) {
    std::string recovered;

    try {
        CryptoPP::GCM<CryptoPP::AES>::Decryption d;
        d.SetKeyWithIV(key, key.size(), iv, iv.size());

        // The StreamTransformationFilter removes
        //  padding as required.
        CryptoPP::StringSource(cipher, true, 
            new CryptoPP::AuthenticatedDecryptionFilter(d,
                new CryptoPP::StringSink(recovered)
            ) // StreamTransformationFilter
        ); // StringSource

    }
    catch(const CryptoPP::Exception& e) {
        std::cerr << e.what() << std::endl;
        exit(1);
    }
    return recovered;
}

上述代码,使用AES-GCM加密模式对数据进行加密、解密,其中key和iv为加密算法的参数,keySize为key的长度。

加密流程:

  1. 初始化加密器,设置key和iv
  2. 读取文件内容并存储在字符串data中
  3. 使用加密器对data进行加密,加密后的内容存储在字符串cipher中

解密流程:

  1. 初始化解密器,设置key和iv
  2. 读取加密后的文件内容并存储在字符串cipher中
  3. 使用解密器对cipher进行解密,解密后的内容存储在字符串recovered中

转换为序列化格式

在推理加密模型时为了防止解密模型暴露在外部,这里首先读取加密模型的二进制数据,其次对二进制数据进行解密,然后将解密后的数据转换为序列化格式,最后通过runtime->deserializeCudaEngine加载引擎数据。

// 打开文件 "engine.trt" 以二进制模式进行读取
std::ifstream engineFile("engine.trt", std::ios::binary);
if (engineFile) {
    // 将文件指针移动到文件的末尾
    engineFile.seekg(0, std::ios::end); 
    // 获取文件的大小(以字节为单位)
    std::streampos fsize = engineFile.tellg();
    // 将文件指针移动回文件的开头
    engineFile.seekg(0, std::ios::beg);

    // 创建一个字符串对象以存储引擎数据(密文),初始化为0
    std::string engineCipher(fsize, 0);

    // 从文件中读取引擎数据(密文)到 engineCipher
    engineFile.read(&engineCipher[0], fsize);
    engineFile.close();

    // 解密引擎数据
    std::string engineRecover = Decrypt(engineCipher, key, iv);

    // 将解密后的数据转换为一个字节向量
    std::vector<unsigned char> engineData(engineRecover.begin(), engineRecover.end());

    // 创建 TensorRT 运行时(Runtime)对象
    std::unique_ptr<nvinfer1::IRuntime> runtime{nvinfer1::createInferRuntime(sample::gLogger.getTRTLogger())};

    // 使用 TensorRT runtime 对象加载引擎数据
    std::unique_ptr<nvinfer1::ICudaEngine> mEngine(runtime->deserializeCudaEngine(engineData.data(), engineData.size(), nullptr));
}

使用MAC地址作为密钥

一般情况下,我们只想让客户在指定的机器上运行模型,这时候就需要使用机器的唯一标识作为密钥,这里使用MAC地址作为密钥。

#pragma once

#include <string>

#ifdef _WIN32
#include <windows.h>
#include <iphlpapi.h>
#pragma comment(lib, "IPHLPAPI.lib")
#else
#include <ifaddrs.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#endif

inline std::string GetMACAddress() {
    std::string macAddress = "";

#ifdef _WIN32
    IP_ADAPTER_INFO* adapterInfo = nullptr;
    ULONG bufferSize = 0;

    if (GetAdaptersInfo(adapterInfo, &bufferSize) == ERROR_BUFFER_OVERFLOW) {
        adapterInfo = new IP_ADAPTER_INFO[bufferSize / sizeof(IP_ADAPTER_INFO)];
        if (GetAdaptersInfo(adapterInfo, &bufferSize) == ERROR_SUCCESS) {
            for (PIP_ADAPTER_INFO adapter = adapterInfo; adapter; adapter = adapter->Next) {
                if (adapter->Type == MIB_IF_TYPE_ETHERNET) {
                    char mac[18];
                    snprintf(mac, sizeof(mac), "%.2X:%.2X:%.2X:%.2X:%.2X:%.2X",
                        adapter->Address[0], adapter->Address[1], adapter->Address[2],
                        adapter->Address[3], adapter->Address[4], adapter->Address[5]);
                    macAddress = mac;
                    break;
                }
            }
        }
        delete[] adapterInfo;
    }
#else
    struct ifaddrs *ifaddr, *ifa;
    if (getifaddrs(&ifaddr) != -1) {
        for (ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) {
            if (ifa->ifa_addr == nullptr || ifa->ifa_addr->sa_family != AF_PACKET) {
                continue;
            }

            struct sockaddr_ll *s = reinterpret_cast<struct sockaddr_ll*>(ifa->ifa_addr);
            char mac[18];
            snprintf(mac, sizeof(mac), "%.2X:%.2X:%.2X:%.2X:%.2X:%.2X",
                s->sll_addr[0], s->sll_addr[1], s->sll_addr[2],
                s->sll_addr[3], s->sll_addr[4], s->sll_addr[5]);
            macAddress = mac;
            break;
        }
        freeifaddrs(ifaddr);
    }
#endif

    return macAddress;
}

上述代码,使用了不同的API获取MAC地址,其中Windows使用GetAdaptersInfo函数,Linux使用getifaddrs函数。

有了MAC地址,就可以将MAC地址转换为密钥,这里使用SHA256算法对MAC地址进行哈希,然后取前32个字节作为密钥。

std::string GenerateAESKey(const std::string& content) {
    CryptoPP::byte hash[CryptoPP::SHA256::DIGESTSIZE];
    CryptoPP::SHA256().CalculateDigest(hash, reinterpret_cast<const CryptoPP::byte*>(content.c_str()), content.length());

    CryptoPP::HexEncoder encoder;
    std::string encodedHash;
    encoder.Attach(new CryptoPP::StringSink(encodedHash));
    encoder.Put(hash, sizeof(hash));
    encoder.MessageEnd();

    return encodedHash.substr(0, 32); // AES-256 key length is 32 bytes
}

添加头部信息

为了新的文件能够被区分和可迭代,除了加密后的数据外还添加了头部信息,比如为了判断该文件类型使用固定的魔数作为文件的开头;为了便于后面需求迭代写入版本号以示区别;为了能够在解密时判断是否采用了相同的密钥将加密时的密钥进行SHA256计算后存储;这三部分构成了目前加密后文件的头部信息。加密后的文件包含头部信息 + 密文信息。

// Encrypt function with header information
std::string EncryptWithHeader(const std::string &data, const std::string &keyStr, const std::string &ivStr, const std::string &magicNumber, const std::string &version) {
    std::string cipher;

    try {
        CryptoPP::SecByteBlock key(reinterpret_cast<const CryptoPP::byte *>(keyStr.data()), keyStr.size());
        CryptoPP::SecByteBlock iv(reinterpret_cast<const CryptoPP::byte *>(ivStr.data()), ivStr.size());
        CryptoPP::GCM<CryptoPP::AES>::Encryption e;
        e.SetKeyWithIV(key, key.size(), iv, iv.size());

        std::string keyHash = CalculateSHA256(key);

        const size_t magicNumberLength = magicNumber.length();
        const size_t versionLength = version.length();
        const size_t keyHashLength = keyHash.length();

        // Construct the header
        std::string header;
        header.reserve(MAGIC_NUMBER_LEN + VERSION_NUMBER_LEN + KEY_HASH_LENGTH_LEN);
        header += ConvertToString(magicNumberLength, MAGIC_NUMBER_LEN);
        header += magicNumber;
        header += ConvertToString(versionLength, VERSION_NUMBER_LEN);
        header += version;
        header += ConvertToString(keyHashLength, KEY_HASH_LENGTH_LEN);
        header += keyHash;

        // Encrypt the data
        CryptoPP::StringSource s(data, true,
            new CryptoPP::AuthenticatedEncryptionFilter(e,
                new CryptoPP::StringSink(cipher)
            )
        );

        // Prepend the header to the cipher
        cipher = header + cipher;
    } catch(const CryptoPP::Exception& e) {
        // Handle the exception gracefully, e.g., log an error message
        std::cerr << "Encryption error: " << e.what() << std::endl;
    }

    return cipher;
}


// Decrypt function for header-aware encrypted data
std::string DecryptWithHeader(const std::string &cipher, const std::string &keyStr, const std::string &ivStr) {
    std::string recovered;

    try {
        CryptoPP::SecByteBlock key(reinterpret_cast<const CryptoPP::byte *>(keyStr.data()), keyStr.size());
        CryptoPP::SecByteBlock iv(reinterpret_cast<const CryptoPP::byte *>(ivStr.data()), ivStr.size());
        CryptoPP::GCM<CryptoPP::AES>::Decryption d;
        d.SetKeyWithIV(key, key.size(), iv, iv.size());

        std::string magicNumberLengthStr = cipher.substr(0, MAGIC_NUMBER_LEN);
        size_t magicNumberLength = std::stoull(magicNumberLengthStr);
        std::string magicNumber = cipher.substr(MAGIC_NUMBER_LEN, magicNumberLength);

        std::string versionLengthStr = cipher.substr(MAGIC_NUMBER_LEN + magicNumberLength, VERSION_NUMBER_LEN);
        size_t versionLength = std::stoull(versionLengthStr);
        std::string version = cipher.substr(MAGIC_NUMBER_LEN + magicNumberLength + VERSION_NUMBER_LEN, versionLength);

        std::string keyHashLengthStr = cipher.substr(MAGIC_NUMBER_LEN + magicNumberLength + VERSION_NUMBER_LEN + versionLength, KEY_HASH_LENGTH_LEN);
        size_t keyHashLength = std::stoull(keyHashLengthStr);
        std::string keyHash = cipher.substr(MAGIC_NUMBER_LEN + magicNumberLength + VERSION_NUMBER_LEN + versionLength + KEY_HASH_LENGTH_LEN, keyHashLength);

        std::string encryptedData = cipher.substr(MAGIC_NUMBER_LEN + magicNumberLength + VERSION_NUMBER_LEN + versionLength + KEY_HASH_LENGTH_LEN + keyHashLength);

        if (!VerifyKey(key, keyHash)) {
            std::cerr << "Key verification failed." << std::endl;
            return "";
        }

        CryptoPP::StringSource(encryptedData, true,
            new CryptoPP::AuthenticatedDecryptionFilter(d,
                new CryptoPP::StringSink(recovered)
            )
        );
    } catch(const CryptoPP::Exception& e) {
        // Handle the exception gracefully, e.g., log an error message
        std::cerr << "Decryption error: " << e.what() << std::endl;
    }

    return recovered;
}

上述代码中,加密函数EncryptWithHeader中,首先计算密钥的SHA256哈希值,然后将魔数、版本号、密钥哈希值、密文依次拼接,然后使用AES-256-GCM算法加密,最后将头部信息和密文拼接返回。解密函数DecryptWithHeader中,首先从密文中提取出头部信息,然后使用密钥哈希值验证密钥是否正确,最后使用AES-256-GCM算法解密密文。

参考资料

posted @ 2023-08-09 17:38  laugh12321  阅读(2083)  评论(0编辑  收藏  举报