Jensen不等式的可视化
Jenson不等式描述对于一个凸函数,期望值与函数作用后的期望值之间的关系。
对于积分为1的非负函数$p(x)$,即
$\displaystyle\int_{-\infty}^{\infty}p(x) dx = 1$
假设$f(x)$为下凸函数,$g(x)$为任意可测函数,Jenson不等式定义如下:
$\displaystyle f\left(\int_{-\infty}^{\infty}p(x)g(x)dx\right) \leq \int_{-\infty}^{\infty}p(x)f(g(x))dx$
转换成期望的形式就是:
$\displaystyle f(E_{x\sim p(x)}g(x)) \leq E_{x\sim p(x)} f(g(x))$
例子
下图举了一个一般离散情况下的Jenson不等式的理解例子:
直观理解,对于下凸函数$f(x)$,四个点$\{(x_i,f(x_i))\}_{i=0}^3$的加权和点被包在由这些点确定的凸多边形内部。加权和点作垂线到$x$轴与$f(x)$的交点即可得到$f(E(x_i))$,由于下凸,可以直观看到$f(E(x_i))$在$E(f(x_i))$的下面。
画图代码:
import numpy as np import matplotlib.pyplot as plt from matplotlib import rcParams # plt.rcParams['text.usetex'] = True plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}' plt.rcParams['font.family'] = 'serif' plt.rcParams['font.serif'] = ['Computer Modern'] plt.rcParams.update({'font.size': 15}) config = { "font.family":'Times New Roman', "mathtext.fontset":'stix', } rcParams.update(config) # 定义一个简单的凸函数 f(x) = x^2 def f(x): return x ** 2 # 定义x的4个值 x_values = np.array([-3, -1, 3, 4]) f_values = f(x_values) # 定义每个点的权重 weights = np.array([0.3, 0.1, 0.2, 0.4]) # 计算加权平均 mean_x = np.sum(weights * x_values) mean_f_x = np.sum(weights * f_values) # 创建图形 fig, ax = plt.subplots(figsize=(8, 6)) # 画出凸函数 x = np.linspace(np.min(x_values)-1, np.max(x_values)+1, 500) ax.plot(x, f(x), label=r'$f(x) = x^2$', color='blue') # 画出四个点 ax.scatter(x_values, f_values, color='black', zorder=5) for i in range(4): ax.text(x_values[i]+0.5, f_values[i]+0.5, f'($x_%d$, $f(x_%d)$)'%(i, i), fontsize=15, ha='center', color='black') # 画出加权平均点 ax.scatter(mean_x, mean_f_x, color='red', zorder=5) ax.text(mean_x, mean_f_x+0.5, f'Weighted Mean', fontsize=15, ha='center', color='red') ax.plot([mean_x, mean_x], [0, mean_f_x], 'red', linestyle='--') ax.text(mean_x, -1, r'$E(x_i)$', fontsize=15, ha='center', color='red') # 添加从加权平均点到y轴的垂直线 ax.plot([-4, mean_x], [mean_f_x, mean_f_x], 'red', linestyle='--') ax.text(-4, mean_f_x, r'$E(f(x_i))$', fontsize=15, va='center', ha='right', color='red') f_mean_x = f(mean_x) ax.plot([mean_x, mean_x], [0, mean_f_x], 'red', linestyle='--') ax.plot([-4, mean_x], [f_mean_x, f_mean_x], 'red', linestyle='--') ax.text(-4, f_mean_x, f'$f(E(x_i))$', fontsize=15, va='center', ha='right', color='red') ax.scatter(mean_x, f_mean_x, color='red', zorder=5) for i in range(4): ax.plot([x_values[i], mean_x], [f_values[i], mean_f_x], color='gray', linestyle='--') # 连接四个点的线 x_values_loop = np.append(x_values, x_values[0]) # 添加第一个点到最后,形成闭环 f_values_loop = np.append(f_values, f_values[0]) # 同样处理 f(x) 值 ax.plot(x_values_loop, f_values_loop, color='gray', linestyle='--', marker='o', markersize=5, zorder=4) # 添加标题和标签 ax.set_title('Jensen\'s Inequality Illustration', fontsize=16) ax.set_xlabel('$x$', fontsize=14) ax.set_ylabel('$f(x)$', fontsize=14) ax.set_ylim(0) ax.set_xlim(-4) # 显示图形 plt.grid(False) plt.legend() plt.savefig('t.svg') plt.show()