LOADING . . .

Jensen不等式的可视化

  Jenson不等式描述对于一个凸函数,期望值与函数作用后的期望值之间的关系。

  对于积分为1的非负函数$p(x)$,即

$\displaystyle\int_{-\infty}^{\infty}p(x) dx = 1$

  假设$f(x)$为下凸函数,$g(x)$为任意可测函数,Jenson不等式定义如下:

$\displaystyle f\left(\int_{-\infty}^{\infty}p(x)g(x)dx\right) \leq \int_{-\infty}^{\infty}p(x)f(g(x))dx$

  转换成期望的形式就是:

$\displaystyle f(E_{x\sim p(x)}g(x)) \leq E_{x\sim p(x)} f(g(x))$

例子

  下图举了一个一般离散情况下的Jenson不等式的理解例子:

  直观理解,对于下凸函数$f(x)$,四个点$\{(x_i,f(x_i))\}_{i=0}^3$的加权和点被包在由这些点确定的凸多边形内部。加权和点作垂线到$x$轴与$f(x)$的交点即可得到$f(E(x_i))$,由于下凸,可以直观看到$f(E(x_i))$在$E(f(x_i))$的下面。

  画图代码:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
# plt.rcParams['text.usetex'] = True
plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Computer Modern']
plt.rcParams.update({'font.size': 15})
config = {
    "font.family":'Times New Roman',
    "mathtext.fontset":'stix',
}
rcParams.update(config)

# 定义一个简单的凸函数 f(x) = x^2
def f(x):
    return x ** 2

# 定义x的4个值
x_values = np.array([-3, -1, 3, 4])
f_values = f(x_values)

# 定义每个点的权重
weights = np.array([0.3, 0.1, 0.2, 0.4])

# 计算加权平均
mean_x = np.sum(weights * x_values)
mean_f_x = np.sum(weights * f_values)

# 创建图形
fig, ax = plt.subplots(figsize=(8, 6))

# 画出凸函数
x = np.linspace(np.min(x_values)-1, np.max(x_values)+1, 500)
ax.plot(x, f(x), label=r'$f(x) = x^2$', color='blue')

# 画出四个点
ax.scatter(x_values, f_values, color='black', zorder=5)
for i in range(4):
    ax.text(x_values[i]+0.5, f_values[i]+0.5, f'($x_%d$, $f(x_%d)$)'%(i, i), fontsize=15, ha='center', color='black')

# 画出加权平均点
ax.scatter(mean_x, mean_f_x, color='red', zorder=5)
ax.text(mean_x, mean_f_x+0.5, f'Weighted Mean', fontsize=15, ha='center', color='red')

ax.plot([mean_x, mean_x], [0, mean_f_x], 'red', linestyle='--')
ax.text(mean_x, -1, r'$E(x_i)$', fontsize=15, ha='center', color='red')

# 添加从加权平均点到y轴的垂直线
ax.plot([-4, mean_x], [mean_f_x, mean_f_x], 'red', linestyle='--')
ax.text(-4, mean_f_x, r'$E(f(x_i))$', fontsize=15, va='center', ha='right', color='red')

f_mean_x = f(mean_x)
ax.plot([mean_x, mean_x], [0, mean_f_x], 'red', linestyle='--')
ax.plot([-4, mean_x], [f_mean_x, f_mean_x], 'red', linestyle='--')
ax.text(-4, f_mean_x, f'$f(E(x_i))$', fontsize=15, va='center', ha='right', color='red')
ax.scatter(mean_x, f_mean_x, color='red', zorder=5)

for i in range(4):
    ax.plot([x_values[i], mean_x], [f_values[i], mean_f_x], color='gray', linestyle='--')

# 连接四个点的线
x_values_loop = np.append(x_values, x_values[0])  # 添加第一个点到最后,形成闭环
f_values_loop = np.append(f_values, f_values[0])  # 同样处理 f(x) 值
ax.plot(x_values_loop, f_values_loop, color='gray', linestyle='--', marker='o', markersize=5, zorder=4)


# 添加标题和标签
ax.set_title('Jensen\'s Inequality Illustration', fontsize=16)
ax.set_xlabel('$x$', fontsize=14)
ax.set_ylabel('$f(x)$', fontsize=14)
ax.set_ylim(0)
ax.set_xlim(-4)

# 显示图形
plt.grid(False)
plt.legend()
plt.savefig('t.svg')
plt.show()

 

posted @ 2024-12-07 14:53  颀周  阅读(67)  评论(0编辑  收藏  举报
很高兴能帮到你~
点赞