why cross entropy loss works
cross entropy
Cross entropy is a measure of the difference between two probability distributions p and q defined over the same underlying random variable:
$$H(p, q) = -\sum_{x_i \in X} p(x_i) \log q(x_i)$$
In a classification problem, the random variable $X$ represents the possible category of an instance, and $p(x_i)$ or $q(x_i)$ is the probability that the instance belongs to category $x_i$. The two distributions come from different places: $p(x_i)$ is the true distribution known from the training data, while $q(x_i)$ is the distribution produced by the algorithm. The goal of the cross entropy loss is to make $q(x_i)$ equal to $p(x_i)$, so that the algorithm makes the right classification.
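As a concrete illustration, here is a minimal sketch in NumPy that evaluates $H(p, q)$ for a three-category example; the distributions p, q_good, and q_bad are made-up values, not taken from any dataset.

import numpy as np

def cross_entropy(p, q):
    # H(p, q) = -sum_i p(x_i) * log q(x_i) for discrete distributions
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return -np.sum(p * np.log(q))

# true distribution from the training data: the instance belongs to category x_1
p = [0.0, 1.0, 0.0]
# two candidate distributions produced by the algorithm
q_good = [0.05, 0.90, 0.05]   # close to p -> small loss
q_bad = [0.60, 0.20, 0.20]    # far from p -> large loss

print(cross_entropy(p, q_good))  # about 0.105
print(cross_entropy(p, q_bad))   # about 1.609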
why cross entropy loss works
The short answer is that when $q(x_i)$ is the same as $p(x_i)$, the cross entropy $H(p, q)$ reaches its minimum.
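A short justification (a standard identity, not spelled out in the original post): the cross entropy splits into the entropy of $p$ plus the KL divergence from $p$ to $q$, and the KL divergence is non-negative and zero exactly when the two distributions coincide.

$$H(p, q) = \underbrace{-\sum_{x_i \in X} p(x_i) \log p(x_i)}_{H(p)} + \underbrace{\sum_{x_i \in X} p(x_i) \log \frac{p(x_i)}{q(x_i)}}_{D_{\mathrm{KL}}(p \,\|\, q) \,\ge\, 0}$$

Since $H(p)$ does not depend on $q$, minimizing $H(p, q)$ over $q$ amounts to driving $D_{\mathrm{KL}}(p \,\|\, q)$ to zero, which happens exactly when $q = p$.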
To keep the problem simple, let’s take binary classification as an example. The category of an instance is denoted as $X = \{x_0, x_1\}$. Since it is binary classification, the two probabilities are related by
$$p(x_1) = 1 - p(x_0)$$
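A quick numerical check of the claim (the value of p(x_0) below is an arbitrary assumption for illustration): sweeping $q(x_0)$ over a grid and evaluating the binary cross entropy shows the minimum sits at $q(x_0) = p(x_0)$.

import numpy as np

p0 = 0.3                             # assumed true probability p(x_0)
q0 = np.linspace(0.01, 0.99, 99)     # candidate values of q(x_0)

# binary cross entropy: H(p, q) = -(p(x_0) log q(x_0) + p(x_1) log q(x_1))
H = -(p0 * np.log(q0) + (1 - p0) * np.log(1 - q0))

print(q0[np.argmin(H)])              # ~0.30: the minimum is at q(x_0) = p(x_0)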
The relation between $p(x_0)$ and the entropy of $X$ is shown in the left panel of the figure produced by the plotting script below. The entropy of any particular distribution corresponds to a single point on the curve; sweeping $p(x_0)$ over all possible probabilities traces out the whole line.
The cross entropy of $p(\cdot)$ and $q(\cdot)$ in the special case $p(\cdot) = q(\cdot)$ is shown in the middle panel: it is the same entropy curve, drawn along the diagonal $q(x_0) = p(x_0)$.
The cross entropy of $p(\cdot)$ and $q(\cdot)$ with every combination of the two distributions taken into account is shown as the surface in the right panel.
As shown in the figure, no matter what the true distribution $p(\cdot)$ is, and no matter what distribution $q(\cdot)$ the algorithm produces, driving the cross entropy toward its minimum drives $q(\cdot)$ toward $p(\cdot)$. If the algorithm produces the right distribution, it produces the right classification. That is why minimizing the cross entropy loss makes the algorithm produce the right classification.
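To make this concrete, here is a minimal sketch (not from the original post; the learning rate, iteration count, and the value of p(x_0) are arbitrary assumptions) in which gradient descent on the binary cross entropy, with $q(x_0)$ parameterized by a sigmoid, pulls $q(x_0)$ toward $p(x_0)$:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

p0 = 0.3          # true probability p(x_0), an arbitrary example value
z = 0.0           # logit; q(x_0) = sigmoid(z) starts at 0.5
lr = 0.5          # learning rate

for step in range(200):
    q0 = sigmoid(z)
    # gradient of H = -(p0*log(q0) + (1-p0)*log(1-q0)) with respect to z
    grad = q0 - p0
    z -= lr * grad

print(sigmoid(z))  # about 0.30: minimizing the loss drives q(x_0) toward p(x_0)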
plotting script
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time   : 1/4/2019 10:04 PM
# @Author : yusisc (yusisc@gmail.com)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# Cross entropy - Wikipedia
# https://en.wikipedia.org/wiki/Cross_entropy
# 2D and 3D Axes in same Figure — Matplotlib 3.0.2 documentation
# https://matplotlib.org/gallery/mplot3d/mixed_subplots.html

fig = plt.figure(figsize=plt.figaspect(0.4))

p0 = np.linspace(0, 1, 20)
p0 = p0[1:-1]
# print(p.size)
entropy = -(p0 * np.log(p0) + (1 - p0) * np.log(1 - p0))

ax0 = fig.add_subplot(1, 3, 1)
ax0.plot(p0, entropy)
ax0.set_xlabel('p(x0)', color='r')
ax0.set_ylabel('entropy of X', color='r')
ax0.set_title('entropy of random var X')

ax1 = fig.add_subplot(1, 3, 2, projection='3d')
ax1.plot(p0, p0, entropy)
ax1.set_xlabel('p(x0)', color='r')
ax1.set_ylabel('shadow of p(x0)', color='r')
ax1.set_zlabel('entropy of X', color='r')
ax1.set_title('cross entropy of distribution p()\nand p() over random var X')

p, q = np.meshgrid(p0, p0)
cross_entropy = -(p * np.log(q) + (1 - p) * np.log(1 - q))

ax2 = fig.add_subplot(1, 3, 3, projection='3d')
ax2.plot(p0, p0, entropy, 'r+')
ax2.plot_surface(p, q, cross_entropy, cmap=cm.coolwarm,
                 linewidth=0, antialiased=False, alpha=0.7)
ax2.set_xlabel('p(x0)', color='r')
ax2.set_ylabel('q(x0)', color='r')
ax2.set_zlabel('\n\ncross entropy \n of distribution p() and q() \n over the random variable X', color='r')
ax2.set_title('cross entropy of distribution p()\nand q() over random var X')

plt.show()
reference
Cross entropy - Wikipedia
https://en.wikipedia.org/wiki/Cross_entropy
2D and 3D Axes in same Figure — Matplotlib 3.0.2 documentation
https://matplotlib.org/gallery/mplot3d/mixed_subplots.html