SM4 SIMD 指令集优化(arm)
一、SM4简介
SM4 算法于 2012 年被国家密码管理局确定为国家密码行业标准,最初主要用于 WAPI (WLAN Authentication and Privacy Infrastructure) 无线网络中。SM4 算法的出现为将我国商用产品上的密码算法由国际标准替换为国家标准提供了强有力的支撑。随后,SM4 算法被广泛应用于政府办公、公安、银行、税务、电力等信息系统中,其在我国密码行业中占据着及其重要的位置。类似于 DES、AES 算法,SM4 算法也是一种分组密码算法。
SM4的官方文档:点击此处跳转
SM4 算法的原理可以参考该博文:点击此处跳转
二、SIMD指令简介
SIMD(Single Instruction Multiple Data)即单指令流多数据流,是一种采用一个控制器来控制多个处理器,同时对一组数据(又称“数据向量”)中的每一个分别执行相同的操作从而实现空间上的并行性的技术。简单来说就是一个指令能够同时处理多个数据。
intel 的 SIMD 指令集参考文档:点击此处跳转
arm 的 SIMD 指令集参考文档:点击此处跳转
三、SIMD 指令(arm) 优化思路
注:该方法可参考下方的现有论文,本文是对现有文章的复现:(---- 2023.05.27更新 ---)
[1] H.Kwon, H.Kim, S.Eum, et al. Optimized Implementation of SM4 on AVR Microcontrollers, RISC-V Processors, and ARM Processors[J]. IEEE Access, 2022, 10: 80225-80233.
[2] W.Guo. Efficient Constant-Time Implementation of SM4 with Intel GFNI instruction set extension and Arm NEON coprocessor[J]. Cryptology ePrint Archive, 2022.
在 ARM V8 架构中,含有32个128bit的向量寄存器(v0-v31),利用这 32 个向量寄存器,可以高效地进行数据的并行处理。
在 SM4 加解密过程中,耗时最大的部分是 S盒 查表运算,这部分涉及内存的读取,消耗最大。而在 ARM 的 NEON(SIMD指令集中的一种)中可以支持并行查表,简化运算。
3.1 查表指令介绍
ARM 中的查表指令有一个特点,需要先将表加载到寄存器中,然后才能进行查表
ARM V8 的查表指令对应的 C 语言接口为
// 表的大小为 1x128bit
uint8x16_t vqtbl1q_u8(uint8x16_t t, uint8x16_t idx);
// 表的大小为 2x128bit
uint8x16_t vqtbl2q_u8(uint8x16x2_t t, uint8x16_t idx);
// 表的大小为 3x128bit
uint8x16_t vqtbl3q_u8(uint8x16x3_t t, uint8x16_t idx);
// 表的大小为 4x128bit
uint8x16_t vqtbl4q_u8(uint8x16x4_t t, uint8x16_t idx);
文档中对查表指令的介绍为:
Table vector Lookup. This instruction reads each value from the vector elements in the index source SIMD&FP register, uses each result as an index to perform a lookup in a table of bytes that is described by one to four source table SIMD&FP registers, places the lookup result in a vector, and writes the vector to the destination SIMD&FP register. If an index is out of range for the table, the result for that lookup is 0. If more than one source register is used to describe the table, the first source register describes the lowest bytes of the table.
大概的意思如下:
向量查表。该指令读取输入的 idx 向量寄存器中的每个元素,将其作为索引,在由 1~4 个向量寄存器构成的查找表中查找,并将结果放入目的寄存器。如果索引超过表的范围,则该索引查表的结果为 0。如果使用多个寄存器描述表,则第一个寄存器描述表的低字节。
简单来说,就是使用 1~4 个 128bit 的向量寄存器构成查找表,使用另一个 128bit 寄存器作为索引。将结果存入 128bit 寄存器中。
3.2 SM4 并行查表
SM4 的 S盒 一共有 256x8bit,即需要消耗 16 个向量寄存器存储查找表
每次查表最多能使用 4 个向量寄存器作为查找表,也就是说索引的范围为 \(\left[0, \; 64 \right)\) ,所以要进行 4 轮的查表
- 将索引值在 \(\left[0, \; 64 \right)\) 的数据查表
- 将索引值在 \(\left[64, \; 128 \right)\) 的数据查表
- 将索引值在 \(\left[128, \; 196 \right)\) 的数据查表
- 将索引值在 \(\left[196, \; 256 \right)\) 的数据查表
举个例子,假设\(T_0 \to T_3\) 对应着 4 轮的查找表,\(T_i(x) = T(x+64*i),\; T\) 为 256x8bit 查找表
待查表的数据为 \([0, \; 64, \; 128, \; 196]\) (每个数据大小 8bit)
- 第一次查表,由于 \(64, 128, 196\) 超过表的范围,对应位置的查表结果为 0。那么总的查表结果为 \([T_0(0), \; 0, \; 0, \; 0]\)
- 为了进行对索引值在 \(\left[64, \; 128 \right)\) 的数据查表,将原先的索引值减去 64,即 \([196, \; 0, \; 64, \; 128]\) ,其中 \(0 - 64 = 196\) 是 8bit 数据溢出的结果。之后,再对新的索引进行查表,结果为 \([0, \; T_1(0), \; 0, \; 0]\)
- 与 2 类似,对 2 中的新索引再减去 64,又得到一个新的索引 \([128, \; 196, \; 0, \; 64]\) ,查表结果为 \([0,\; 0,\; T_2(0),\; 0]\)
- 与 3 类似,对 3 中的新索引再减去 64,又得到一个新的索引 \([64, \; 128, \; 196, \; 0]\) ,查表结果为 \([0,\; 0,\; 0,\; T_3(0)]\)
- 最后,将 4 轮查表结果合并(异或),得到最终的结果 \([T_0(0),\; T_1(0),\; T_2(0),\; T_3(0)]\) ,即 \([T(0),\; T(64),\; T(128),\; T(196)]\)
3.3 SM4 数据加载
在一轮 SM4 迭代过程中,需要查表的数据为 32bit,为了尽可能利用好 128bit 的向量寄存器,将 4 组消息的 32bit 集中与同一个向量寄存器中
即对于消息 \(A_0, A_1, A_2, A_3, B_0, \cdots, D_3\) ,将其装载成 \([A_0, B_0, C_0, D_0], \cdots [A_3, B_3, C_3, D_3]\) (每个字母代表 32bit 数)
在存储时,再转变回原先的顺序
在 ARM V8 中,有指令能够方便地进行上述操作,分别是
// 读取
uint32x4x4_t vld4q_u32 (uint32_t const * ptr);
// 储存
void vst4q_u32 (uint32_t * ptr, uint32x4x4_t val);
此外,还需要考虑字节序的问题,需要使用vrev32q_u8
指令进行 32bit 内字节逆序
四、代码实现
指令集优化主要针对加解密过程,密钥生成过程不过优化,故密钥生成算法依旧是普通的实现方法。
编译命令为gcc -o main -O3 main.c sm4.c sm4_x4.c
(需要在 armv8 架构)
4.1 sm4.h
// sm4.h
#ifndef SM4_H
#define SM4_H
#include <stdint.h>
/**
* @brief SM4 密钥
*/
typedef struct _SM4_Key {
uint32_t rk[32];//32轮密钥
} SM4_Key;
/**
* @brief 初始化 SM4 轮密钥
* @param key 128bit长度密钥
* @param sm4_key SM4 密钥
*/
void SM4_KeyInit(uint8_t* key, SM4_Key* sm4_key);
#endif
4.2 sm4.c
//sm4.c
#include "sm4.h"
static uint32_t FK[4] = {0xa3b1bac6, 0x56aa3350, 0x677d9197, 0xb27022dc};
static uint32_t CK[32] = {
0x00070e15, 0x1c232a31, 0x383f464d, 0x545b6269, 0x70777e85, 0x8c939aa1,
0xa8afb6bd, 0xc4cbd2d9, 0xe0e7eef5, 0xfc030a11, 0x181f262d, 0x343b4249,
0x50575e65, 0x6c737a81, 0x888f969d, 0xa4abb2b9, 0xc0c7ced5, 0xdce3eaf1,
0xf8ff060d, 0x141b2229, 0x30373e45, 0x4c535a61, 0x686f767d, 0x848b9299,
0xa0a7aeb5, 0xbcc3cad1, 0xd8dfe6ed, 0xf4fb0209, 0x10171e25, 0x2c333a41,
0x484f565d, 0x646b7279};
static uint8_t SBox[256] = {
0xD6, 0x90, 0xE9, 0xFE, 0xCC, 0xE1, 0x3D, 0xB7, 0x16, 0xB6, 0x14, 0xC2,
0x28, 0xFB, 0x2C, 0x05, 0x2B, 0x67, 0x9A, 0x76, 0x2A, 0xBE, 0x04, 0xC3,
0xAA, 0x44, 0x13, 0x26, 0x49, 0x86, 0x06, 0x99, 0x9C, 0x42, 0x50, 0xF4,
0x91, 0xEF, 0x98, 0x7A, 0x33, 0x54, 0x0B, 0x43, 0xED, 0xCF, 0xAC, 0x62,
0xE4, 0xB3, 0x1C, 0xA9, 0xC9, 0x08, 0xE8, 0x95, 0x80, 0xDF, 0x94, 0xFA,
0x75, 0x8F, 0x3F, 0xA6, 0x47, 0x07, 0xA7, 0xFC, 0xF3, 0x73, 0x17, 0xBA,
0x83, 0x59, 0x3C, 0x19, 0xE6, 0x85, 0x4F, 0xA8, 0x68, 0x6B, 0x81, 0xB2,
0x71, 0x64, 0xDA, 0x8B, 0xF8, 0xEB, 0x0F, 0x4B, 0x70, 0x56, 0x9D, 0x35,
0x1E, 0x24, 0x0E, 0x5E, 0x63, 0x58, 0xD1, 0xA2, 0x25, 0x22, 0x7C, 0x3B,
0x01, 0x21, 0x78, 0x87, 0xD4, 0x00, 0x46, 0x57, 0x9F, 0xD3, 0x27, 0x52,
0x4C, 0x36, 0x02, 0xE7, 0xA0, 0xC4, 0xC8, 0x9E, 0xEA, 0xBF, 0x8A, 0xD2,
0x40, 0xC7, 0x38, 0xB5, 0xA3, 0xF7, 0xF2, 0xCE, 0xF9, 0x61, 0x15, 0xA1,
0xE0, 0xAE, 0x5D, 0xA4, 0x9B, 0x34, 0x1A, 0x55, 0xAD, 0x93, 0x32, 0x30,
0xF5, 0x8C, 0xB1, 0xE3, 0x1D, 0xF6, 0xE2, 0x2E, 0x82, 0x66, 0xCA, 0x60,
0xC0, 0x29, 0x23, 0xAB, 0x0D, 0x53, 0x4E, 0x6F, 0xD5, 0xDB, 0x37, 0x45,
0xDE, 0xFD, 0x8E, 0x2F, 0x03, 0xFF, 0x6A, 0x72, 0x6D, 0x6C, 0x5B, 0x51,
0x8D, 0x1B, 0xAF, 0x92, 0xBB, 0xDD, 0xBC, 0x7F, 0x11, 0xD9, 0x5C, 0x41,
0x1F, 0x10, 0x5A, 0xD8, 0x0A, 0xC1, 0x31, 0x88, 0xA5, 0xCD, 0x7B, 0xBD,
0x2D, 0x74, 0xD0, 0x12, 0xB8, 0xE5, 0xB4, 0xB0, 0x89, 0x69, 0x97, 0x4A,
0x0C, 0x96, 0x77, 0x7E, 0x65, 0xB9, 0xF1, 0x09, 0xC5, 0x6E, 0xC6, 0x84,
0x18, 0xF0, 0x7D, 0xEC, 0x3A, 0xDC, 0x4D, 0x20, 0x79, 0xEE, 0x5F, 0x3E,
0xD7, 0xCB, 0x39, 0x48};
#define rotl32(value, shift) ((value << shift) | value >> (32 - shift))
void SM4_KeyInit(uint8_t* key, SM4_Key* sm4_key) {
uint32_t k[4];
uint32_t tmp;
uint8_t* tmp_ptr8 = (uint8_t*)&tmp;
// 初始化密钥
for (int i = 0; i < 4; i++) {
int j = 4 * i;
k[i] = (key[j + 0] << 24) | (key[j + 1] << 16) | (key[j + 2] << 8) |
(key[j + 3]);
k[i] = k[i] ^ FK[i];
}
// 32轮变换
for (int i = 0; i < 32; i++) {
tmp = k[1] ^ k[2] ^ k[3] ^ CK[i];
// SBox 盒变换
for (int j = 0; j < 4; j++) {
tmp_ptr8[j] = SBox[tmp_ptr8[j]];
}
// 线性变换
sm4_key->rk[i] = k[0] ^ tmp ^ rotl32(tmp, 13) ^ rotl32(tmp, 23);
// 移位
k[0] = k[1];
k[1] = k[2];
k[2] = k[3];
k[3] = sm4_key->rk[i];
}
}
4.3 sm4_x4.h
//sm4_x4.h
#ifndef SM4_X4_H
#define SM4_X4_H
#include"sm4.h"
/**
* @brief SM4 4组并行加密
* @param plaintext 128x4bit明文
* @param ciphertext 128x4bit密文
* @param sm4_key SM4密钥
*/
void SM4_Encrypt_x4(uint8_t* plaintext, uint8_t* ciphertext, SM4_Key* sm4_key);
/**
* @brief SM4 4组并行解密
* @param ciphertext 128x4bit密文
* @param plaintextt 128x4bit明文
* @param sm4_key SM4 密钥
*/
void SM4_Decrypt_x4(uint8_t* ciphertext, uint8_t* plaintext, SM4_Key* sm4_key);
#endif
4.4 sm4_x4.c
//sm4_x4.c
#include "sm4_x4.h"
#include <arm_neon.h>
static uint8_t SBox[256] = {
0xD6, 0x90, 0xE9, 0xFE, 0xCC, 0xE1, 0x3D, 0xB7, 0x16, 0xB6, 0x14, 0xC2,
0x28, 0xFB, 0x2C, 0x05, 0x2B, 0x67, 0x9A, 0x76, 0x2A, 0xBE, 0x04, 0xC3,
0xAA, 0x44, 0x13, 0x26, 0x49, 0x86, 0x06, 0x99, 0x9C, 0x42, 0x50, 0xF4,
0x91, 0xEF, 0x98, 0x7A, 0x33, 0x54, 0x0B, 0x43, 0xED, 0xCF, 0xAC, 0x62,
0xE4, 0xB3, 0x1C, 0xA9, 0xC9, 0x08, 0xE8, 0x95, 0x80, 0xDF, 0x94, 0xFA,
0x75, 0x8F, 0x3F, 0xA6, 0x47, 0x07, 0xA7, 0xFC, 0xF3, 0x73, 0x17, 0xBA,
0x83, 0x59, 0x3C, 0x19, 0xE6, 0x85, 0x4F, 0xA8, 0x68, 0x6B, 0x81, 0xB2,
0x71, 0x64, 0xDA, 0x8B, 0xF8, 0xEB, 0x0F, 0x4B, 0x70, 0x56, 0x9D, 0x35,
0x1E, 0x24, 0x0E, 0x5E, 0x63, 0x58, 0xD1, 0xA2, 0x25, 0x22, 0x7C, 0x3B,
0x01, 0x21, 0x78, 0x87, 0xD4, 0x00, 0x46, 0x57, 0x9F, 0xD3, 0x27, 0x52,
0x4C, 0x36, 0x02, 0xE7, 0xA0, 0xC4, 0xC8, 0x9E, 0xEA, 0xBF, 0x8A, 0xD2,
0x40, 0xC7, 0x38, 0xB5, 0xA3, 0xF7, 0xF2, 0xCE, 0xF9, 0x61, 0x15, 0xA1,
0xE0, 0xAE, 0x5D, 0xA4, 0x9B, 0x34, 0x1A, 0x55, 0xAD, 0x93, 0x32, 0x30,
0xF5, 0x8C, 0xB1, 0xE3, 0x1D, 0xF6, 0xE2, 0x2E, 0x82, 0x66, 0xCA, 0x60,
0xC0, 0x29, 0x23, 0xAB, 0x0D, 0x53, 0x4E, 0x6F, 0xD5, 0xDB, 0x37, 0x45,
0xDE, 0xFD, 0x8E, 0x2F, 0x03, 0xFF, 0x6A, 0x72, 0x6D, 0x6C, 0x5B, 0x51,
0x8D, 0x1B, 0xAF, 0x92, 0xBB, 0xDD, 0xBC, 0x7F, 0x11, 0xD9, 0x5C, 0x41,
0x1F, 0x10, 0x5A, 0xD8, 0x0A, 0xC1, 0x31, 0x88, 0xA5, 0xCD, 0x7B, 0xBD,
0x2D, 0x74, 0xD0, 0x12, 0xB8, 0xE5, 0xB4, 0xB0, 0x89, 0x69, 0x97, 0x4A,
0x0C, 0x96, 0x77, 0x7E, 0x65, 0xB9, 0xF1, 0x09, 0xC5, 0x6E, 0xC6, 0x84,
0x18, 0xF0, 0x7D, 0xEC, 0x3A, 0xDC, 0x4D, 0x20, 0x79, 0xEE, 0x5F, 0x3E,
0xD7, 0xCB, 0x39, 0x48};
/**
* @brief SM4 4组并行
* @param in 128x4bit输入
* @param out 128x4bit输出
* @param sm4_key SM4密钥
* @param enc 加密(0)或者解密(1)
*/
static void _SM4_do_x4(uint8_t* in, uint8_t* out, SM4_Key* sm4_key, int enc);
void SM4_Encrypt_x4(uint8_t* plaintext, uint8_t* ciphertext, SM4_Key* sm4_key) {
_SM4_do_x4(plaintext, ciphertext, sm4_key, 0);
}
void SM4_Decrypt_x4(uint8_t* ciphertext, uint8_t* plaintext, SM4_Key* sm4_key) {
_SM4_do_x4(ciphertext, plaintext, sm4_key, 1);
}
static void _SM4_do_x4(uint8_t* in, uint8_t* out, SM4_Key* sm4_key, int enc) {
const static uint8_t SubData[16] = {64, 64, 64, 64, 64, 64, 64, 64,
64, 64, 64, 64, 64, 64, 64, 64};
uint8x16_t dec;
uint8x16x4_t Table[4]; // S表
uint32x4x4_t x; //数据
uint32x4_t tmp_32x4, tmp1_32x4, tmp2_32x4, tmp3_32x4;
uint8x16_t tmp_8x16, tmp1_8x16, tmp2_8x16;
//---------------------Load Data------------------
dec = vld1q_u8(SubData);
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
Table[i].val[j] = vld1q_u8(SBox + 64 * i + 16 * j);
}
}
//加载数据,32bit一组分组
x = vld4q_u32((uint32_t*)in);
// 32bit内8bit逆序
x.val[0] = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(x.val[0])));
x.val[1] = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(x.val[1])));
x.val[2] = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(x.val[2])));
x.val[3] = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(x.val[3])));
//-------------------Loop 32-------------------
for (int i = 0; i < 32; i++) {
//----------X1 xor X2 xor X3 xor Kr------------
tmp_32x4 =
vdupq_n_u32((enc == 0) ? (sm4_key->rk[i]) : (sm4_key->rk[31 - i]));
tmp_32x4 = veorq_u32(tmp_32x4, x.val[1]);
tmp_32x4 = veorq_u32(tmp_32x4, x.val[2]);
tmp_32x4 = veorq_u32(tmp_32x4, x.val[3]);
//--------------------S Box-------------------
tmp_8x16 = vreinterpretq_u8_u32(tmp_32x4); //类型转换
tmp1_8x16 = vqtbl4q_u8(Table[0], tmp_8x16); //第一次查表
tmp_8x16 = vsubq_u8(tmp_8x16, dec);
tmp2_8x16 = vqtbl4q_u8(Table[1], tmp_8x16); //第二次查表
tmp1_8x16 = veorq_u8(tmp1_8x16, tmp2_8x16);
tmp_8x16 = vsubq_u8(tmp_8x16, dec);
tmp2_8x16 = vqtbl4q_u8(Table[2], tmp_8x16); //第三次查表
tmp1_8x16 = veorq_u8(tmp1_8x16, tmp2_8x16);
tmp_8x16 = vsubq_u8(tmp_8x16, dec);
tmp2_8x16 = vqtbl4q_u8(Table[3], tmp_8x16); //第四次查表
tmp1_8x16 = veorq_u8(tmp1_8x16, tmp2_8x16);
tmp_32x4 = vreinterpretq_u32_u8(tmp1_8x16); //类型转换
//--------------------L---------------------
x.val[0] = veorq_u32(x.val[0], tmp_32x4);
tmp1_32x4 = vshlq_n_u32(tmp_32x4, 2);
tmp2_32x4 = vshrq_n_u32(tmp_32x4, 32 - 2);
tmp3_32x4 = veorq_u32(tmp1_32x4, tmp2_32x4); //循环左移2位
x.val[0] = veorq_u32(x.val[0], tmp3_32x4);
tmp1_32x4 = vshlq_n_u32(tmp_32x4, 10);
tmp2_32x4 = vshrq_n_u32(tmp_32x4, 32 - 10);
tmp3_32x4 = veorq_u32(tmp1_32x4, tmp2_32x4); //循环左移10位
x.val[0] = veorq_u32(x.val[0], tmp3_32x4);
tmp1_32x4 = vshlq_n_u32(tmp_32x4, 18);
tmp2_32x4 = vshrq_n_u32(tmp_32x4, 32 - 18);
tmp3_32x4 = veorq_u32(tmp1_32x4, tmp2_32x4); //循环左移18位
x.val[0] = veorq_u32(x.val[0], tmp3_32x4);
tmp1_32x4 = vshlq_n_u32(tmp_32x4, 24);
tmp2_32x4 = vshrq_n_u32(tmp_32x4, 32 - 24);
tmp3_32x4 = veorq_u32(tmp1_32x4, tmp2_32x4); //循环左移24位
x.val[0] = veorq_u32(x.val[0], tmp3_32x4);
//
tmp_32x4 = x.val[0];
x.val[0] = x.val[1];
x.val[1] = x.val[2];
x.val[2] = x.val[3];
x.val[3] = tmp_32x4;
}
//逆序x
tmp_32x4 = x.val[0];
x.val[0] = x.val[3];
x.val[3] = tmp_32x4;
tmp_32x4 = x.val[1];
x.val[1] = x.val[2];
x.val[2] = tmp_32x4;
// 32bit内8bit逆序
x.val[0] = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(x.val[0])));
x.val[1] = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(x.val[1])));
x.val[2] = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(x.val[2])));
x.val[3] = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(x.val[3])));
//存储数据
vst4q_u32((uint32_t*)out, x);
}
4.5 main.c
//main.c
#include <stdio.h>
#include "sm4.h"
#include "sm4_x4.h"
int main() {
SM4_Key sm4_key;
uint8_t key[16] = {0};
uint8_t a[16 * 4] = {0};
uint8_t b[16 * 4];
SM4_KeyInit(key, &sm4_key);
SM4_Encrypt_x4(a, b, &sm4_key);
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 16; j++) {
printf("%02x ", b[16 * i + j]);
}
printf("\n");
}
return 0;
}