【机器学习】softmax回归

Softmax Regression

(多标签分类)将多输入的分类值转化为[0,1]的概率分布,进而进行逻辑回归算法

softmax能将差距大的数值距离拉得更大,但是数值可能会溢出

Softmax Function

数学表达式

\[a_j = \frac{e^{z_j}}{ \sum_{k=1}^{N}{e^{z_k} }} \]

代码

def my_softmax(z):
    ez = np.exp(z)              #element-wise exponenial
    sm = ez/np.sum(ez)
    return(sm)

cost function

数学表达式

loss:

\[\begin{equation} L(\mathbf{a},y)=\begin{cases} -log(a_1), & \text{if $y=1$}.\\ &\vdots\\ -log(a_N), & \text{if $y=N$} \end{cases} = -log\left(\frac{e^{z_2}}{ \sum_{i=1}^{N}{e^{z_i} }}\right) \end{equation} \]

化简

\[L(\mathbf{z})= -\left[z_2 - log( \sum_{i=1}^{N}{e^{z_i} })\right] = \underbrace{log \sum_{i=1}^{N}{e^{z_i} }}_\text{logsumexp()} -z_2 = C+ log( \sum_{i=1}^{N}{e^{z_i-C} }) -z_2 \;\;\;\text{where } C=max_j(\mathbf{z}) \]

cost function

\[\begin{align} J(\mathbf{w},b) = - \left[ \sum_{i=1}^{m} \sum_{j=1}^{N} 1\left\{y^{(i)} == j\right\} \log \frac{e^{z^{(i)}_j}}{\sum_{k=1}^N e^{z^{(i)}_k} }\right] \end{align} \]

代码

# compute f_x
model = Sequential(
    [ 
        Dense(25, activation = 'relu'),
        Dense(15, activation = 'relu'),
        Dense(4, activation = 'softmax')    # < softmax activation here
    ]
)
# compute loss
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(0.001),
)
# gradient descent
model.fit(
    X_train,y_train,
    epochs=10
)

Numerical Stability

该模型使用指数函数,因而大数容易溢出,需要处理

数学原理

\[\begin{align} a_j &= \frac{e^{z_j}}{ \sum_{i=1}^{N}{e^{z_i} }} \frac{e^{-max_j(\mathbf{z})}}{ {e^{-max_j(\mathbf{z})}}} \\ &= \frac{e^{z_j-max_j(\mathbf{z})}}{ \sum_{i=1}^{N}{e^{z_i-max_j(\mathbf{z})} }} \end{align} \]

化简

\[a_j = \frac{e^{z_j-C}}{ \sum_{i=1}^{N}{e^{z_i-C} }} \quad\quad\text{where}\quad C=max_j(\mathbf{z}) \]

代码

def my_softmax_ns(z):
    """numerically stablility improved"""
    bigz = np.max(z)
    ez = np.exp(z-bigz)              # minimize exponent
    sm = ez/np.sum(ez)
    return(sm)

多类分类和多标签分类不一样

  • 多标签:无关物品分类
  • 多类:有概率关系的特征分类
posted @ 2023-07-31 15:44  码农要战斗  阅读(22)  评论(0编辑  收藏  举报