C++分段函数的SSE指令优化

对于分段函数,当其输入参数x属于不同的范围区间时,其表达式是不一样的。按照通常的C++实现来说,如果要计算4个不同x的函数值,则需要分别判断这4个x所属于的范围区间,然后根据各自所属的范围区间来决定各自的函数表达式,所以这4个x的函数表达式很可能是不一样的。然而SSE指令优化的核心思想就是在一条CPU指令内同时对4个数进行相同的运算,如果这4个数的运算不一样,那么将无法使用SSE指令对其同时处理。

在这种情况下,我们可以想办法把分段函数转化成一个统一的表达式。假如有以下分段函数:

那么我们可以把F(x)构造成以下形式:

上式中,可以通过分别比较x与a1、a2的大小来决定c1、c2、c3的值,如下表所示。由于三个条件是互斥的,c1、c2、c3中只有一个是1,其余两个是0,因此构造的函数与原分段函数是等效的。


x<a1
a1≤x≤a2
x>a2
c1
1
0
0
c2
0
1
0
c3
0
01

在SSE指令中,假如比较两个浮点数大小的结果是true,其返回的不是1,而是0xffff;如果是false则返回0x0。所以F(x)需要修改为下式,即使用按位与&符号代替乘号*。

下面举一个更加具体的例子,分别给出C++实现代码和SSE指令优化代码。分段函数为:

该函数的C++代码如下,该代码每调用一次只能计算一个x的函数值。

float cal_coeff(float x, float a1, float a2)
{
  if(x > a2)
    return x;
  else if(x < a1)
    return x*x/a1;
  else if( x >= a1 && x <= a2)    //y=a(x-h)2+k(a≠0),(h,k)为抛物线的顶点
    return (x-a2)*(x-a2)/(a1-a2) + a2;
}

下面我们使用SSE指令进行优化,优化之后每调用函数一次,同时求得4个不同x值的函数值。首先我们复习一下SSE指令中__m128数据类型的结构,__m128总共具有128 bits数据,也就是4个float型数据,如下表。SSE指令运算的基本数据类型为__m128,也就是说其同时处理4个float型数据。

3
2
1
0
float 3
float 2float 1float 0

SSE指令优化代码如下。

__m128 cal_coeff_sse(__m128 x, __m128 a1, __m128 a2, __m128 a1_a2)
{
  __m128 c11 = _mm_cmpge_ps(x, a1);   //返回一个_m128的寄存器,比较x≥a1,分别比较寄存器a的每个32bit浮点数是否大于寄存器b对应位置32bit浮点数,若大于,该位置返回0xffff,否则返回0x0
  __m128 c22 = _mm_cmple_ps(x, a2);   //比较x≤a2,分别比较寄存器a的每个32bit浮点数是否小于寄存器b对应位置32bit浮点数,若大于,该位置返回0xffff,否则返回0x0
  __m128 c1 = _mm_cmpgt_ps(x, a2);   //比较x>a2
  __m128 c2 = _mm_and_ps(c11, c22);  //比较是否a1≤x≤a2,返回为一个_m128的寄存器,将寄存器a和寄存器b的对应位置的32bit单精度浮点数分别进行按位与运算
  __m128 c3 = _mm_cmplt_ps(x, a1);   //比较x<a1
  
  //x>a2的情况
  __m128 f1 = _mm_and_ps(x, c1);     //计算c1&f1(x) 


  //a1≤x≤a2的情况
  __m128 f2 = _mm_sub_ps(x, a2);  //x-a2    
  f2 = _mm_mul_ps(f2, f2);     //(x-a2)*(x-a2)
  f2 = _mm_div_ps(f2, a1_a2);   //(x-a2)*(x-a2)/(a1-a2)
  f2 = _mm_add_ps(f2, a2);     //(x-a2)*(x-a2)/(a1-a2)+a2
  f2 = _mm_and_ps(f2, c2);    //计算c2&f2(x)


  //x<a1的情况
  __m128 f3 = _mm_mul_ps(x, x);   //x*x
  f3 = _mm_div_ps(f3, a1);     //x*x/a1
  f3 = _mm_and_ps(f3, c3);     //计算c3&f3(x)


  //计算F(x) = c1&f1(x) + c2&f2(x) + c3&f3(x)
  __m128 f = _mm_add_ps(_mm_add_ps(f1, f2), f3);


  return f;
}
posted @ 2020-10-18 22:40  萌萌哒程序猴  阅读(263)  评论(0编辑  收藏  举报