编写程序打印所有 FP8(E4M3)格式的浮点数

FP8 (e4m3) 是一种浮点数表示格式,其中:

  • e4 表示有 4 位用于指数。
  • m3 表示有 3 位用于尾数(即有效数字)。

为了生成并打印所有 FP8 (e4m3) 格式的浮点数,我们需要了解这个格式的细节。FP8 (e4m3) 的格式可以分解为:

  • 符号位(1 位)
  • 指数(4 位)
  • 尾数(3 位)

我们可以用以下步骤来生成所有可能的 FP8 (e4m3) 浮点数:

  1. 计算 FP8 (e4m3) 浮点数的不同组合。
  2. 处理特殊值,例如零和无穷大。
  3. 将这些组合转换为浮点数值并打印出来。

以下是 Python 程序的示例代码:

import numpy as np

def fp8_to_float(sign, exponent, mantissa):
    # FP8 e4m3 的指数偏移量为 7
    EXP_BIAS = 7
    # 计算实际指数
    exp = exponent - EXP_BIAS
    
    # 特殊值处理
    if exponent == 0 and mantissa == 0:
        return 0.0 if sign == 0 else -0.0
    elif exponent == 0b1111:  # 指数全1,处理无穷大和NaN
        if mantissa == 0:
            return float('inf') if sign == 0 else float('-inf')
        else:
            return float('nan')
    
    # 计算尾数的值
    mantissa_value = mantissa / (2**3)  # 尾数范围是 [0, 1)
    
    # 计算最终浮点值
    return (-1)**sign * (1 + mantissa_value) * (2 ** exp)

def generate_fp8_values():
    for sign in range(2):  # 0 和 1
        for exponent in range(16):  # 4 位二进制范围
            for mantissa in range(8):  # 3 位二进制范围
                value = fp8_to_float(sign, exponent, mantissa)
                print(f"sign={sign}, exponent={exponent:04b}, mantissa={mantissa:03b} => {value}")

generate_fp8_values()

代码解释

  1. fp8_to_float 函数:

    • 计算实际的浮点值。
    • 使用了指数偏移量(bias),在 FP8 (e4m3) 中,偏移量为 7。
    • 处理了特殊值(如零、无穷大和 NaN)。
  2. generate_fp8_values 函数:

    • 遍历所有可能的符号位、指数位和尾数位的组合。
    • 使用 fp8_to_float 函数计算实际浮点值,并打印结果。

运行此代码将打印 FP8 (e4m3) 格式中所有可能的浮点数及其对应的值。

打印出这些浮点数的点图

import numpy as np
import matplotlib.pyplot as plt

def fp8_to_float(sign, exponent, mantissa):
    EXP_BIAS = 7
    exp = exponent - EXP_BIAS
    
    if exponent == 0 and mantissa == 0:
        return 0.0 if sign == 0 else -0.0
    elif exponent == 0b1111:
        if mantissa == 0:
            return float('inf') if sign == 0 else float('-inf')
        else:
            return float('nan')
    
    mantissa_value = mantissa / (2**3)
    return (-1)**sign * (1 + mantissa_value) * (2 ** exp)

def generate_fp8_values():
    values = []
    for sign in range(2):
        for exponent in range(16):
            for mantissa in range(8):
                value = fp8_to_float(sign, exponent, mantissa)
                values.append(value)
    return values

def plot_fp8_values(values):
    # 过滤掉 NaN 和 inf
    values = [v for v in values if not (np.isnan(v) or np.isinf(v))]
    
    plt.figure(figsize=(10, 6))
    plt.scatter(range(len(values)), values, marker='o')
    plt.title('FP8 (e4m3) Float Point Values')
    plt.xlabel('Index')
    plt.ylabel('Value')
    plt.grid(True)
    plt.show()

# 生成 FP8 浮点数值
values = generate_fp8_values()
# 绘制点图
plot_fp8_values(values)

输出图

E4M3输出图

E5M2输出图

** 注意观察 y 轴的输出范围,可以看到 E5M2 的值域更广,所以相比 E4M3 不容易溢出,进而导致梯度上升。故E5M2 更适合训练模型。**

posted @ 2024-08-15 13:25  立体风  阅读(92)  评论(0编辑  收藏  举报