SoftMax 回归(与Logistic 回归的联系与区别)

SoftMax 回归(与Logistic 回归的联系与区别)

SoftMax 试图解决的问题

SoftMax回归模型是Logistic回归模型在多分类问题上的推广,即在多分类问题中,类标签y可以取两个以上的值

对于Logistic回归的假设函数\(h_\theta(x) = \frac{1}{1 + \exp(-x)}\),它的输出结果将被投影到\([0,1]\)区间上,根据假设函数的输出值的大小,我们预测该输入值是否属于某一个类别,其结果只会是不是,即Logistic回归只能解决二分类问题.

SoftMax实现多分类的思路很简单: 对于每一个分类,输出一个假设值,用于判定当前输入值对应该类的概率,最终根据各个类的概率大小判定该输入值对应的分类. 可以看出,SoftMax的思想有点类似于独热编码

SoftMax回归的假设函数,代价函数及正则化

SoftMax的假设函数如下:

\[h_\theta(x^{(i)}) = \left[ \begin{matrix} & p(y^{(1)}| x^{(i)};\theta_1)\\ & p(y^{(2)}|x^{(i)};\theta_2)\\ &:\\ &p(y^{(k)}|x^{(i)};\theta_k) \end{matrix} \right] = \frac{1}{\sum_{j=1}^{k}e^{\theta_j^T}x^{(i)}} \left[ \begin{matrix} & e^{\theta_1^Tx^{(i)}}\\ & e^{\theta_2^Tx^{(i)}}\\ &:\\ & e^{\theta_k^Tx^{(i)}} \end{matrix} \right] \]

其中\(\theta_1,\theta_2,...,\theta_k \in \R^{n+1}\)是训练模型的参数,右式分母意义在于归一化至\([0,1]\)区间,使得所有概率和为一

转化为矩阵运算,我们用

\[\Theta = \left[ \begin{matrix} &\theta_1^T \\ &\theta_2^T \\ &:\\ &\theta_k^T \end{matrix} \right] \]

来表示所有的参数. 显然,\(\Theta\)是一个\(k*(n+1)\)的矩阵

代价函数为

image-20201203193353937

image-20201203193525550

image-20201203193556867

image-20201203193615180

K-Logistic 与 SoftMax 的选择

image-20201203193632942


Update: 部分代码节选

下面简单介绍利用softmax训练模型时,所需实现的求解损失函数、梯度的代码

function [cost,grad] = softMaxCost(theta,n,k,X,oneHoty,lambda)
% 参数介绍
% theta 传入为一个向量
% n为输入维度
% k为最终分类树
% X为n*m的输入样本,每一列代表一个样本
% lambda为正则项
% - - - - - - - - - - - - - - - - - - - - - - - - - - - -
theta = reshape(theta,k,n); % 规模k*n
m = size(X,2); % 样本数
groudTruth = oneHoty; % 真实值
% - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
cost = 0;
thetagrad = zeros(size(theta));
% - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
temp = theta * X;
temp = exp(temp);
H = bsxfun(@rdivide,temp,sum(temp)); 
temp = log(H);
temp = groudTruth .* temp;
% - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
cost = -1/m * sum(sum(temp)) + lambda/2*sum(sum(theta.^2));
thetagrad = -1/m * (groundTruth-H)*X' + lambda * theta;
grad = [thetagrad(:)];
end
posted @ 2020-12-03 19:38  popozyl  阅读(1001)  评论(0编辑  收藏  举报