编写程序打印所有 FP8(E4M3)格式的浮点数
FP8 (e4m3) 是一种浮点数表示格式,其中:
e4
表示有 4 位用于指数。m3
表示有 3 位用于尾数(即有效数字)。
为了生成并打印所有 FP8 (e4m3) 格式的浮点数,我们需要了解这个格式的细节。FP8 (e4m3) 的格式可以分解为:
- 符号位(1 位)
- 指数(4 位)
- 尾数(3 位)
我们可以用以下步骤来生成所有可能的 FP8 (e4m3) 浮点数:
- 计算 FP8 (e4m3) 浮点数的不同组合。
- 处理特殊值,例如零和无穷大。
- 将这些组合转换为浮点数值并打印出来。
以下是 Python 程序的示例代码:
import numpy as np
def fp8_to_float(sign, exponent, mantissa):
# FP8 e4m3 的指数偏移量为 7
EXP_BIAS = 7
# 计算实际指数
exp = exponent - EXP_BIAS
# 特殊值处理
if exponent == 0 and mantissa == 0:
return 0.0 if sign == 0 else -0.0
elif exponent == 0b1111: # 指数全1,处理无穷大和NaN
if mantissa == 0:
return float('inf') if sign == 0 else float('-inf')
else:
return float('nan')
# 计算尾数的值
mantissa_value = mantissa / (2**3) # 尾数范围是 [0, 1)
# 计算最终浮点值
return (-1)**sign * (1 + mantissa_value) * (2 ** exp)
def generate_fp8_values():
for sign in range(2): # 0 和 1
for exponent in range(16): # 4 位二进制范围
for mantissa in range(8): # 3 位二进制范围
value = fp8_to_float(sign, exponent, mantissa)
print(f"sign={sign}, exponent={exponent:04b}, mantissa={mantissa:03b} => {value}")
generate_fp8_values()
代码解释
-
fp8_to_float
函数:- 计算实际的浮点值。
- 使用了指数偏移量(bias),在 FP8 (e4m3) 中,偏移量为 7。
- 处理了特殊值(如零、无穷大和 NaN)。
-
generate_fp8_values
函数:- 遍历所有可能的符号位、指数位和尾数位的组合。
- 使用
fp8_to_float
函数计算实际浮点值,并打印结果。
运行此代码将打印 FP8 (e4m3) 格式中所有可能的浮点数及其对应的值。
打印出这些浮点数的点图
import numpy as np
import matplotlib.pyplot as plt
def fp8_to_float(sign, exponent, mantissa):
EXP_BIAS = 7
exp = exponent - EXP_BIAS
if exponent == 0 and mantissa == 0:
return 0.0 if sign == 0 else -0.0
elif exponent == 0b1111:
if mantissa == 0:
return float('inf') if sign == 0 else float('-inf')
else:
return float('nan')
mantissa_value = mantissa / (2**3)
return (-1)**sign * (1 + mantissa_value) * (2 ** exp)
def generate_fp8_values():
values = []
for sign in range(2):
for exponent in range(16):
for mantissa in range(8):
value = fp8_to_float(sign, exponent, mantissa)
values.append(value)
return values
def plot_fp8_values(values):
# 过滤掉 NaN 和 inf
values = [v for v in values if not (np.isnan(v) or np.isinf(v))]
plt.figure(figsize=(10, 6))
plt.scatter(range(len(values)), values, marker='o')
plt.title('FP8 (e4m3) Float Point Values')
plt.xlabel('Index')
plt.ylabel('Value')
plt.grid(True)
plt.show()
# 生成 FP8 浮点数值
values = generate_fp8_values()
# 绘制点图
plot_fp8_values(values)
输出图
E4M3输出图
E5M2输出图
** 注意观察 y 轴的输出范围,可以看到 E5M2 的值域更广,所以相比 E4M3 不容易溢出,进而导致梯度上升。故E5M2 更适合训练模型。**