【代码分享】使用 avx512 + 查表法,优化凯撒加密
作者:张富春(ahfuzhang),转载时请注明作者和引用链接,谢谢!
关于凯撒加密,具体请看:https://en.wikipedia.org/wiki/Caesar_cipher
总而言之就是玩点没什么用的小心眼,把字母的顺序变化一下。
第一版:根据业务逻辑直接实现:
void caesarEncodeV0(uint8_t* out, uint8_t* in, int len, int rot){
rot = rot % 26;
uint8_t* end = in + len;
uint32_t* line = table.Table[rot];
for (;in<end; in++, out++){
if (*in>='a' && *in<='z'){
*out = (*in - 'a' + rot)%26 + 'a';
} else if (*in>='A' && *in<='Z'){
*out = (*in - 'A' + rot)%26 + 'A';
} else {
*out = *in;
}
}
}
void testCaesar(){
const char* s = "QAULi2jah2eqSD1zQAULhuG0Qs9mhOF9TDGtFAGtFqB=";
int len = strlen(s);
uint8_t* out = malloc(len+1);
int rot = 4;
caesarEncodeV0(out, s, len, rot);
out[len] = '\0';
printf("in :%s\n", s);
printf("out:%s | caesarEncodeV0\n", out);
}
第二版:使用查表法
很明显,字符间的替换,可以预先放在一个数组里,然后查表就行了。
typedef struct{
uint32_t Table[26][256]; // uint32_t 要比 uint8_t 更好,猜测是因为字节对齐的原因
} __attribute__((packed)) CaesarTable;
// 预先计算替换规则后的结果
void initTable(uint32_t *table[26][256]){
for (int i=0; i<26; i++){
for (int j=0; j<256; j++){
if (j>='a' && j<='z'){
table[i][j] = (uint8_t)((j-'a'+i)%26 + 'a');
} else if (j>='A' && j<='Z'){
table[i][j] = (uint8_t)((j-'A'+i)%26 + 'A');
} else {
table[i][j] = j;
}
}
}
}
void caesarEncodeV1(uint8_t* out, uint8_t* in, int len, int rot, uint32_t *table[26][256]){
rot = rot % 26;
uint8_t* end = in + len;
uint32_t* line = table[rot];
for (;in<end; in++, out++){
*out = (uint8_t)line[*in]; // 直接查表得到结果
}
}
CaesarTable table;
void testCaesar(){
initTable(&table.Table);
//
const char* s = "QAULi2jah2eqSD1zQAULhuG0Qs9mhOF9TDGtFAGtFqB=";
int len = strlen(s);
uint8_t* out = malloc(len+1);
int rot = 4;
caesarEncodeV1(out, s, len, rot, &table.Table);
out[len] = '\0';
printf("in :%s\n", s);
printf("out:%s | caesarEncodeV1\n", out);
}
第三版:使用 simd,一个周期内查表多次
查看文档发现,只有 avx512 指令集才能很好的支持查表操作。
void caesarEncodeSIMDV2(uint8_t* out, uint8_t* in, int len, int rot, uint32_t *table[26][256]){
rot = rot % 26;
uint32_t* line = table[rot];
const batchSize = 16;
uint8_t* end = in + len - (len&0x0f); // 每个批次处理 16 字节,不够 16 字节的尾部要单独处理
uint8_t* start = in;
for (; start<end; start += batchSize, out += batchSize){
_mm_storeu_epi8( // step5: 把 16 个 int8 存储到目的地址
out, _mm512_cvtepi32_epi8( // step4: 把 16 个 int32 的查表结果,转换成 16 个 int8
_mm512_i32gather_epi32( // step3: 把 16 个 int32 当成偏移量,在 table 开始的地址里面查询. 最后一个参数 4,表示查表中每个元素的偏移量是 4 字节
_mm512_cvtepu8_epi32( // step2: 把 16 个 int8 转换成 16 个 int32
_mm_loadu_si128(start) // step1: 以非对齐的方式,从源地址加载 16 字节
), line, 4))
);
}
end = in + len;
for (; start<end; start++, out++){
*out = (uint8_t)line[*start];
}
}
编译命令行为:
gcc -o caesar caesar.c -g -w -mavx -mavx2 -mavx512f -mavx512vl -mavx512bw -O2
最后测试的结果为:
- 逐个字符查表:67.041 ns/op
- avx512 查表:36.371 ns/op