Loading [MathJax]/jax/element/mml/optable/GeneralPunctuation.js

Hiroki

大部分笔记已经转移到 https://github.com/hschen0712/machine_learning_notes ,QQ:357033150, 欢迎交流

softmax分类器+cross entropy损失函数的求导

softmax是logisitic regression在多酚类问题上的推广,W=[w1,w2,...,wc]为各个类的权重因子,b为各类的门槛值。不要想象成超平面,否则很难理解,如果理解成每个类的打分函数,则会直观许多。预测时我们把样本分配到得分最高的类。

Notations:

  • x:输入向量,d×1列向量,d是feature数
  • W:权重矩阵,c×d矩阵,c是label数
  • b:每个类对应超平面的偏置组成的向量, c×1列向量
  • z=Wx+b:线性分类器输出, c×1列向量
  • ˆy:softmax函数输出, c×1列向量
  • ej=[0,...,1,...,0]TRc×1,其中1出现在第j个位置
  • 1c表示一个全1c维列向量
  • y:我们要拟合的目标变量,是一个one-hot vector(只有一个1,其余均为0),也是 c×1列向量 。 我们将其转置,表示为一个列向量:

y=[0,...,1,...,0]T

他们之间的关系:

{z=Wx+bˆy=softmax(z)=exp(z)1Tcexp(z)

cross-entropy error定义为:

CE(z)=yTlog(ˆy)

因为y是一个one-hot vector(即只有一个位置为1),假设yk=1,那么上式等于log(ˆyk)=log(exp(zk)iexp(zi))=zk+log(iexp(zi))

依据chain rule有:

CE(z)Wij=tr((CE(z)z)TzWij)=tr((ˆyzCE(z)ˆy)TzWij)

注:这里我用了Denominator layout,因此链式法则是从右往左的。

我们一个一个来求。

ˆyz=(exp(z)1Tcexp(z))z=11Tcexp(z)exp(z)z+(11Tcexp(z))z(exp(z))T=11Tcexp(z)diag(exp(z))1(1Tcexp(z))2exp(z)exp(z)T=diag(exp(z)1Tcexp(z))exp(z)1Tcexp(z)(exp(z)1Tcexp(z))T=diag(softmax(z))softmax(z)softmax(z)T=diag(ˆy)ˆyˆyT

注:上述求导过程使用了Denominator layout
设$a=a( \boldsymbol{ x}),\boldsymbol{u}= \boldsymbol{u}( \boldsymbol{x}) \boldsymbol{ x}a \boldsymbol{u}加粗表示是一个向量函数。在`Numerator layout`下,\frac{\partial a \boldsymbol{u}}{ \boldsymbol{x}}=a\frac{\partial \boldsymbol{u}}{\partial \boldsymbol{x}}+ \boldsymbol{u}\frac{\partial a}{\partial \boldsymbol{x}} ,而在`Denominator layout`下,则为\frac{\partial a \boldsymbol{u}}{\partial \boldsymbol{x}}=a\frac{\partial \boldsymbol{u}}{\partial \boldsymbol{x}}+\frac{\partial a}{\partial \boldsymbol{x}} \boldsymbol{u}^T$,对比可知上述推导用的实际是Denominator layout
以下推导均采用 Denominator layout,这样的好处是我们用梯度更新权重时不需要对梯度再转置。

\begin{equation}\frac{\partial CE(z)}{\partial \hat{y}}=\frac{\partial log(\hat{y})}{\partial \hat{y}}\cdot \frac{\partial (-y^Tlog(\hat{y}))}{\partial log(\hat{y})}=\big(diag(\hat{y})\big)^{-1}\cdot(-y)\label{eq2}\end{equation}

z的第k个分量可以表示为:z_k=\sum\limits_j W_{kj}x_j+b_k,因此

\begin{equation}\frac{\partial z}{\partial W_{ij}} =\begin{bmatrix}\frac{\partial z_1}{\partial W_{ij}}\\\vdots\\\frac{\partial z_c}{\partial W_{ij}}\end{bmatrix}=[0,\cdots, x_j,\cdots, 0]^T=x_j \vec{e}_i\label{eq3}\end{equation}

其中x_j是向量x的第j个元素,为标量,它出现在第i行。
综合\eqref{eq1},\eqref{eq2},\eqref{eq3},我们有

\begin{aligned}\frac{\partial CE(z)}{\partial W_{ij}}&=tr\bigg(\big( (diag(\hat{y})-\hat{y}\hat{y}^T)\cdot (diag(\hat{y}))^{-1} \cdot (-y) \big)^T\cdot x_j \vec{e}_i \bigg)\\&=tr\bigg(\big( \hat{y}\cdot (1_c^Ty)-y\big)^T\cdot x_j \vec{e}_i \bigg)\\&=(\hat{y}-y)^T\cdot x_j \vec{e}_i={err}_ix_j\end{aligned}

其中{err}_i=(\hat{y}-y)_i表示残差向量的第i

我们可以把上式改写为

\frac{\partial CE(z)}{\partial W}=(\hat{y}-y)\cdot x^T

同理可得

\frac{\partial CE(z)}{\partial b}=(\hat{y}-y)

那么在进行随机梯度下降的时候,更新式就是:

\begin{aligned}&W \leftarrow W - \lambda (\hat{y}-y)\cdot x^T \\&b \leftarrow b - \lambda (\hat{y}-y)\end{aligned}

其中\lambda是学习率

posted on   Hiroki  阅读(10907)  评论(0编辑  收藏  举报

编辑推荐:
· 为什么说在企业级应用开发中,后端往往是效率杀手?
· 用 C# 插值字符串处理器写一个 sscanf
· Java 中堆内存和栈内存上的数据分布和特点
· 开发中对象命名的一点思考
· .NET Core内存结构体系(Windows环境)底层原理浅谈
阅读排行:
· 为什么说在企业级应用开发中,后端往往是效率杀手?
· 本地部署DeepSeek后,没有好看的交互界面怎么行!
· DeepSeek 解答了困扰我五年的技术问题。时代确实变了!
· 趁着过年的时候手搓了一个低代码框架
· 推荐一个DeepSeek 大模型的免费 API 项目!兼容OpenAI接口!
< 2025年2月 >
26 27 28 29 30 31 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 1
2 3 4 5 6 7 8

导航

统计

点击右上角即可分享
微信分享提示