混合密度网络(MDN)进行多元回归详解和代码示例

在本文中,首先简要解释一下 混合密度网络 MDN (Mixture Density Network)是什么,然后将使用Python 代码构建 MDN 模型,最后使用构建好的模型进行多元回归并测试效果。

回归

“回归预测建模是逼近从输入变量 (X) 到连续输出变量 (y) 的映射函数 (f) [...] 回归问题需要预测具体的数值。 具有多个输入变量的问题通常被称为多元回归问题 例如,预测房屋价值,可能在 100,000 美元到 200,000 美元之间

这是另一个区分分类问题和回归问题的视觉解释如下:

另外一个例子

密度

DENSITY “密度” 是什么意思? 这是一个快速的通俗示例:

假设正在为必胜客运送比萨。 现在记录刚刚进行的每次交付的时间(以分钟为单位)。 交付 1000 次后,将数据可视化以查看工作表现如何。 这是结果:

这是披萨交付时间数据分布的“密度”。平均而言,每次交付需要 30 分钟(图中的峰值)。 它还表示,在 95% 的情况下(2 个标准差2sd ),交付需要 20 到 40 分钟才能完成。 密度种类代表时间结果的“频率”。 “频率”和“密度”的区别在于:

· 频率:如果你在这条曲线下绘制一个直方图并对所有的 bin 进行计数,它将求和为任何整数(取决于数据集中捕获的观察总数)。

· 密度:如果你在这条曲线下绘制一个直方图并计算所有的 bin,它总和为 1。我们也可以将此曲线称为概率密度函数 (pdf)。

用统计术语来说,这是一个漂亮的正态/高斯分布。 这个正态分布有两个参数:

均值

· 标准差:“标准差是一个数字,用于说明一组测量值如何从平均值(平均值)或预期值中展开。低标准偏差意味着大多数数字接近平均值。高标准差意味着数字更加分散。“

均值和标准差的变化会影响分布的形状。 例如:

有许多具有不同类型参数的各种不同分布类型。 例如:

混合密度

现在让我们看看这 3 个分布:

如果我们采用这种双峰分布(也称为一般分布):

混合密度网络使用这样的假设,即任何像这种双峰分布的一般分布都可以分解为正态分布的混合(该混合也可以与其他类型的分布一起定制 例如拉普拉斯):

网络架构

混合密度网络也是一种人工神经网络。 这是神经网络的经典示例:

输入层(黄色)、隐藏层(绿色)和输出层(红色)。

如果我们将神经网络的目标定义为学习在给定一些输入特征的情况下输出连续值。 在上面的例子中,给定年龄、性别、教育程度和其他特征,那么神经网络就可以进行回归的运算。

密度网络

密度网络也是神经网络,其目标不是简单地学习输出单个连续值,而是学习在给定一些输入特征的情况下输出分布参数(此处为均值和标准差)。 在上面的例子中,给定年龄、性别、教育程度等特征,神经网络学习预测期望工资分布的均值和标准差。预测分布比预测单个值具有很多的优势,例如能够给出预测的不确定性边界。 这是解决回归问题的“贝叶斯”方法。下面是预测每个预期连续值的分布的一个很好的例子:

下面的图片向我们展示了每个预测实例的预期值分布:

混合密度网络

最后回到正题,混合密度网络的目标是在给定特定输入特征的情况下,学习输出混合在一般分布中的所有分布的参数(此处为均值、标准差和 Pi)。 新参数“Pi”是混合参数,它给出最终混合中给定分布的权重/概率。

最终结果如下:

示例1:单变量数据的 MDN 类

完整文章

https://www.overfit.cn/post/20245a8446ae43e3982b48e4320991ab

posted @ 2022-02-19 12:05  deephub  阅读(604)  评论(0编辑  收藏  举报