asm内联语句在编译不支持的情况下的解决办法

asm内联语句在编译不支持的情况下的解决办法

本文链接

在编写 CUDA 的 C++ 程序时,NVIDIA 的库文件中包含这样一行代码,这行代码在 NVIDIA的编译器 nvcc 编译时可以正常通过,但迁移到 AMD 的编译器中将无法识别,生成报错。为了便于做迁移工作,我需要获得这行代码的执行信息。

首先观察一下这行报错代码。

asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));

asm代码信息

询问GPT可以得出下面的回答:

这行代码是使用ARM架构的NEON指令集进行编程的一个例子。

NEON是一种SIMD(单指令多数据)架构扩展,用于ARM处理器,它允许开发者通过单条指令对多个数据进行并行处理。

这行代码的具体作用是。

asm volatile: 这是GCC的扩展语法,用于在C或C++代码中嵌入汇编语言。volatile关键字告诉编译器不要对这段汇编代码进行任何优化。

"cvt.rni.sat.s8.f16 %0, %1;" 这是NEON指令的一部分。这条指令的具体作用是:
cvt: 表示转换(convert)。

rni: 表示舍入到最接近的整数,如果有两个整数同样接近,则舍入到偶数。

sat: 表示饱和(saturation),即如果结果超出了目标类型的范围,则将其限制在目标类型的最大或最小值。

s8: 表示目标数据类型是8位有符号整数。

f16: 表示源数据类型是16位浮点数。

%0%1: 分别表示第一个和第二个操作数的位置。

"=h"(int16): 这部分指定了输出操作数%0。"=h"表示这是一个16位(half)的输出操作数,并且它将被存储在变量int16中。

"h"(int16_in): 这部分指定了输入操作数%1。"h"表示这是一个16位(half)的输入操作数,它来自变量int16_in

综上所述,这行代码的作用是将16位浮点数int16_in转换为8位有符号整数,结果存储在int16中。转换过程中会进行舍入和饱和处理。

提取其中的关键信息,可以发现。

> rni: 表示舍入到最接近的整数,如果有两个整数同样接近,则舍入到偶数。
> sat: 表示饱和(saturation),即如果结果超出了目标类型的范围,则将其限制在目标类型的最大或最小值。

此时根据这些信息,可以联想到这是一行作饱和舍入的代码。对于饱和舍入(Saturation Rounding),其具体定义:


  • 饱和舍入(Saturation Rounding)是一种数值处理方法,常用于数字信号处理和图像处理领域。在饱和舍入中,当一个数值需要被转换或舍入到某个特定的数值范围时,如果这个数值超出了目标范围,它不会被简单地截断或进行标准的四舍五入,而是被“饱和”到目标范围的最小值或最大值。

  • 例如,假设我们有一个8位的有符号整数范围,这个范围是从-128到127。如果一个数值在转换过程中计算出的结果是130,那么按照饱和舍入的规则,这个数值会被饱和到127,因为这是这个范围内的最大值。同样,如果一个数值计算出的结果是-130,它会被饱和到-128,因为这是这个范围内的最小值。

  • 饱和舍入的好处是它避免了数据溢出的问题,保持了数据的完整性,并且在某些应用中,如图像处理,它有助于防止图像质量的下降。


有了这些前置信息,我们就可以知道,这行代码做了两件事,就是将传入的数据做了一次舍入操作,再对数据范围做了截取。对于舍入方式,其中也有表明: 舍入到最接近的整数(rni)

rni 是“round to nearest integer”的缩写,表示舍入到最接近的整数。

这种舍入方式遵循以下规则:

如果小数部分正好是0.5,那么结果会舍入到最接近的偶数。这被称为“银行家舍入”或“四舍六入五成双”。
如果小数部分小于0.5,那么结果会向下舍入到更小的整数。
如果小数部分大于或等于0.5,那么结果会向上舍入到更大的整数。
例如,使用rni舍入方法:

> 1.5   舍入为  2
> 2.5   舍入为  2
> -1.5  舍入为  -2
> -2.5  舍入为  -2

在 AMD 支持的内联asm汇编语句和寄存器类型中找不到上述的实现,而且寄存器类型的符号表示也有所差别。于是我采取最简单的实现方式,将这行内联汇编语句直接替换为 C 语句,实现其功能。由于已知了其功能,编写 C 程序也十分的简单。但是由于这行内联语句是直接调用寄存器,运行速度比用 C 编写的语法快,所以简单用 C 替换仅仅是实现了其正确性,性能有所不及。

但是我在 AMD 上尚未找到有对应的汇编指令完成这行代码的实现,因此目前不得不使用这种方法。

语句替换

在不考虑数据类型转换的情况下,我们先来看舍入的规则。以下的数据是实际在 NVIDIA 编译器上调用asm上述代码所实现的结果。可以看到,当数值超过 127 或者小于 -128 的时候,会将数据截断在 127 和 -128 处。这也是8位有符号整数int8_t所能表示的范围(-128 ~ 127)。

> -150.0 舍入为  -128
> -128.0 舍入为  -128
> -1.0   舍入为  -1
> -1.6   舍入为  -2
> -1.5   舍入为  -2
> -1.4   舍入为  -1
> -1.0   舍入为  -1
> 0.0    舍入为  0
> 0.4    舍入为  0
> 0.5    舍入为  0
> 0.6    舍入为  1
> 1.0    舍入为  1
> 126.0  舍入为  126
> 127.0  舍入为  127
> 128.0  舍入为  127
> 200.0  舍入为  127

内联语句中规定了输入输出的操作数类型,输入是一个16位(half)的输入操作数,从我的上下文中可以得知,传入时的类型是 half 类型。输出是一个16位的操作数,并以此指定了操作16位数据的寄存器(h),但是传出的数据类型是int8_t, int8_t是8位数据。

从上面可以得知,我们需要的结果数据储存在 int8_t 类型中就已经足够,内联语句中调用的却是16位的寄存器。因此需要对产生的16位数据进行截取才能获得需要的8位数值。内联语句中的 s8 其实就表示输出的数据类型为8位,只不过借用了16位的寄存器而已。

直接对 halfint8_t 类型之间做转换会产生错误,因为它们不仅数据存储长度不同,表示数值的方式也是不一样的。为了保险起见,可以用 floatint 类型的局部变量储存住数值,作为中间变量,将 half 类型的浮点数转换为期望得到的整数数值。

至于将数值截取到 -128 ~ 127 之间,可以直接将超过范围的数值置为端点值。

__device__ int8_t cvt_f16_to_s8(half val)
{
    float float32 = (float)val;
    int int32 = 0;
    if (float32 > 0)
    {
        if (float32 > 127)
            int32 = 127;
        else
            int32 = (int)(float32 + 0.5);   // 强制数据类型转换
    }
    else if (float32 < 0)
    {
        if (float32 < -128)
            int32 = -128;
        else
            int32 = (int)(float32 - 0.5);  // 强制数据类型转换
    }
    return *((int8_t *)&int32);
}

我的舍入操作中,强制数据类型转换发生在 floatint 类型之间,这样可以保证数值截取时得到预期数值大小。返回值通过得到的 int 类型数值地址,转换为 int8_t * 的指针,并取这个 int8_t 的值返回,这样可以保证返回值是 int8_t 类型。至此完成了上述内联汇编语句的全部功能。将这个 cvt_f16_to_s8(half val); 函数替换掉 asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in)); 即可。

__device__ inline int8_t cuda_cast<int8_t, half>(half val)
{
    union
    {
        int8_t int8[2];
        int16_t int16;
    };
    union
    {
        half fp16;
        int16_t int16_in;
    };
    fp16 = val;

    //asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
    int8_t res = cvt_f16_to_s8(val);  // 通过 C 的语法,用函数实现
    
    return res;
}

__device__ 是运行在 GPU 上的函数 kernel 声明方式,在这里不用在意。 如果对GPU编程感兴趣,可以移步我CUDA入门的教程文档。

CUDA入门必看,如何高效地编写并行程序

posted @ 2024-09-18 17:15  北纬31是条纬线哦  阅读(56)  评论(0编辑  收藏  举报