why cross entropy loss works
cross entropy
Cross entropy measures the difference between two probability distributions $p$ and $q$ defined over the same underlying random variable:
$$H(p, q) = -\sum_{x_i \in X} p(x_i) \log q(x_i)$$
In a classification problem, the random variable $X$ represents the category that an instance may belong to. Both $p(x_i)$ and $q(x_i)$ are probabilities that the instance belongs to category $x_i$, but they come from different distributions: $p(x_i)$ is the true distribution, known from the training data, while $q(x_i)$ is the distribution produced by the algorithm. The goal of the cross entropy loss is to make $q(x_i)$ match $p(x_i)$, so that the algorithm makes the right classification.
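As a quick numerical illustration, here is a minimal sketch that evaluates $H(p, q)$ for a one-hot true distribution $p$ and two candidate predictions $q$; the prediction closer to $p$ gives the smaller cross entropy. (The helper function and the example numbers are chosen only for illustration.)

import numpy as np

def cross_entropy(p, q):
    # H(p, q) = -sum_i p(x_i) * log q(x_i)
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return -np.sum(p * np.log(q))

p = [1.0, 0.0, 0.0]        # true distribution: the instance belongs to category x_0
q_good = [0.8, 0.1, 0.1]   # prediction close to p
q_bad = [0.2, 0.5, 0.3]    # prediction far from p

print(cross_entropy(p, q_good))  # ~0.223
print(cross_entropy(p, q_bad))   # ~1.609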
why cross entropy loss works
The short answer is that $H(p, q)$ reaches its minimum exactly when $q(x_i)$ is the same as $p(x_i)$.
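One way to see this is the standard decomposition of cross entropy into the entropy of $p$ plus the Kullback-Leibler divergence from $p$ to $q$, using the same notation as above:

$$H(p, q) = -\sum_{x_i \in X} p(x_i) \log p(x_i) + \sum_{x_i \in X} p(x_i) \log \frac{p(x_i)}{q(x_i)} = H(p) + D_{\mathrm{KL}}(p \,\|\, q)$$

Since $D_{\mathrm{KL}}(p \,\|\, q) \ge 0$, with equality if and only if $q = p$, minimizing $H(p, q)$ over $q$ drives $q$ toward $p$. The figures below make the same point graphically for the binary case.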
To keep the problem simple, let's take binary classification as an example. The category of an instance is denoted by $X = \{x_0, x_1\}$. Since the classification is binary, the two probabilities are related by

$$p(x_1) = 1 - p(x_0)$$
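Substituting this relation (and the analogous $q(x_1) = 1 - q(x_0)$) into the definition of $H(p, q)$ reduces the cross entropy to a function of the two scalars $p(x_0)$ and $q(x_0)$:

$$H(p, q) = -\bigl[\, p(x_0) \log q(x_0) + (1 - p(x_0)) \log (1 - q(x_0)) \,\bigr]$$

This is the quantity computed as cross_entropy in the plotting script below.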
The relation between $p(x_0)$ and $\mathrm{entropy}(X)$ is shown in the left panel of the figure produced by the plotting script below. The entropy of any particular dataset corresponds to a single point on the curve; sweeping $p(x_0)$ over all possible probabilities traces out the whole line.
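For reference, the curve in that panel is the binary entropy, computed as entropy in the script:

$$\mathrm{entropy}(X) = -\bigl[\, p(x_0) \log p(x_0) + (1 - p(x_0)) \log (1 - p(x_0)) \,\bigr]$$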
The cross entropy of $p(\cdot)$ and $q(\cdot)$ in the special case $p(\cdot) = q(\cdot)$ is shown in the middle panel of the figure.
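Setting $q(x_0) = p(x_0)$ in the binary cross entropy above collapses it to exactly the entropy curve:

$$H(p, q)\big|_{q = p} = -\bigl[\, p(x_0) \log p(x_0) + (1 - p(x_0)) \log (1 - p(x_0)) \,\bigr] = \mathrm{entropy}(X)$$

so the middle panel is the left panel's entropy curve redrawn along the diagonal $p(x_0) = q(x_0)$.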
The cross entropy of $p(\cdot)$ and $q(\cdot)$ over every combination of the two distributions is shown as the surface in the right panel.
As shown in the figure, no matter what the true distribution $p(\cdot)$ is, and no matter what distribution $q(\cdot)$ the algorithm produces, driving the cross entropy toward its minimum drives $q(\cdot)$ toward $p(\cdot)$. If the algorithm produces the right distribution, it produces the right classification. That's why minimizing the cross entropy loss makes the algorithm produce the right classification.
plotting script
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time : 1/4/2019 10:04 PM
# @Author : yusisc (yusisc@gmail.com)
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
# Cross entropy - Wikipedia
# https://en.wikipedia.org/wiki/Cross_entropy
# 2D and 3D Axes in same Figure — Matplotlib 3.0.2 documentation
# https://matplotlib.org/gallery/mplot3d/mixed_subplots.html
fig = plt.figure(figsize=plt.figaspect(0.4))
p0 = np.linspace(0, 1, 20)
p0 = p0[1: -1]  # drop the endpoints 0 and 1 to avoid log(0)
# print(p.size)
entropy = -(p0 * np.log(p0) +
            (1 - p0) * np.log(1 - p0))
ax0 = fig.add_subplot(1, 3, 1)
ax0.plot(p0, entropy)
ax0.set_xlabel('p(x0)', color='r')
ax0.set_ylabel('entropy of X', color='r')
ax0.set_title('entropy of random var X')
ax1 = fig.add_subplot(1, 3, 2, projection='3d')
ax1.plot(p0, p0, entropy)
ax1.set_xlabel('p(x0)', color='r')
ax1.set_ylabel('shadow of p(x0)', color='r')
ax1.set_zlabel('entropy of X', color='r')
ax1.set_title('cross entropy of distribution p()\nand p() over random var X')
p, q = np.meshgrid(p0, p0)  # grid of (p(x0), q(x0)) pairs
cross_entropy = -(p * np.log(q) +
                  (1 - p) * np.log(1 - q))
ax2 = fig.add_subplot(1, 3, 3, projection='3d')
ax2.plot(p0, p0, entropy, 'r+')  # minimum of the surface lies along the diagonal p(x0) = q(x0)
ax2.plot_surface(p, q, cross_entropy, cmap=cm.coolwarm,
                 linewidth=0, antialiased=False, alpha=0.7)
ax2.set_xlabel('p(x0)', color='r')
ax2.set_ylabel('q(x0)', color='r')
ax2.set_zlabel('\n\ncross entropy \n of distribution p() and q() \n over the random variable X', color='r')
ax2.set_title('cross entropy of distribution p()\nand q() over random var X')
plt.show()
references
Cross entropy - Wikipedia
https://en.wikipedia.org/wiki/Cross_entropy
2D and 3D Axes in same Figure — Matplotlib 3.0.2 documentation
https://matplotlib.org/gallery/mplot3d/mixed_subplots.html