why cross entropy loss works

Cross entropy is a quantity computed from two probability distributions $p$ and $q$ over the same underlying random variable.
$$H(p, q) = -\sum_{x_i \in X} p(x_i) \log q(x_i)$$

For a classification problem, the random variable $X$ represents the possible category of an instance, and $p(x_i)$ or $q(x_i)$ is the probability that the instance belongs to category $x_i$. The two distributions belong to different classification systems: $p(x_i)$ is known from the training data, while $q(x_i)$ is produced by the algorithm. The goal of the cross entropy loss is to make $q(x_i)$ equal to $p(x_i)$, so that the algorithm makes the right classification.
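As a rough illustration of the definition, here is a minimal sketch that evaluates $H(p, q)$ for one instance with three categories. The function name `cross_entropy`, the `eps` clipping constant, and the example probabilities are assumptions made for this sketch, not part of the original post.

```python
import numpy as np


def cross_entropy(p, q, eps=1e-12):
    """H(p, q) = -sum_i p(x_i) * log q(x_i), measured in nats."""
    p = np.asarray(p, dtype=float)
    # Clip q away from 0 so that log(q) stays finite (eps is an assumption).
    q = np.clip(np.asarray(q, dtype=float), eps, 1.0)
    return float(-np.sum(p * np.log(q)))


# p comes from the training label: the instance is in the second category.
p_true = [0.0, 1.0, 0.0]
# Two hypothetical outputs of the algorithm.
q_good = [0.1, 0.8, 0.1]
q_bad = [0.6, 0.2, 0.2]

loss_good = cross_entropy(p_true, q_good)  # equals -log(0.8)
loss_bad = cross_entropy(p_true, q_bad)    # equals -log(0.2)
print(loss_good, loss_bad)
```

The prediction that puts more probability on the labelled category gets the smaller loss, which is the behaviour the loss is meant to reward.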

why cross entropy loss works

The short answer is that, for a given $p$, $H(p, q)$ reaches its minimum when $q(x_i)$ equals $p(x_i)$.
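One standard way to see this, sketched here as a supplement to the post, is to split the cross entropy into two parts. The symbols $H(p)$ (the entropy of $p$) and $D_{\mathrm{KL}}(p \,\|\, q)$ (the Kullback-Leibler divergence) are introduced only for this sketch.

```latex
\begin{aligned}
H(p, q) &= -\sum_{x_i \in X} p(x_i) \log q(x_i) \\
        &= -\sum_{x_i \in X} p(x_i) \log p(x_i)
           + \sum_{x_i \in X} p(x_i) \log \frac{p(x_i)}{q(x_i)} \\
        &= H(p) + D_{\mathrm{KL}}(p \,\|\, q)
\end{aligned}
```

The first term does not depend on $q$, and the second term is never negative and is zero only when $q = p$ (Gibbs' inequality). So, with $p$ fixed, the smallest possible value of $H(p, q)$ is $H(p)$, reached at $q = p$.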
To keep the problem simple, take binary classification as an example. The category of an instance is denoted $X = \{x_0, x_1\}$. Because the classification is binary, the two probabilities are related by
$$p(x_1) = 1 - p(x_0)$$
The relation between $p(x_0)$ and $entropy(X)$ is shown below.
[Figure: entropy of random var X, plotted against p(x0)]
The entropy of any particular dataset corresponds to a single point on the curve. The curve covers every possible value of $p(x_0)$, which is why it forms a continuous line.
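To make the points on that curve concrete, the following small sketch evaluates the binary entropy at a few values of $p(x_0)$. The helper name `binary_entropy` and the sample points are assumptions chosen for illustration; the formula is the same one used in the plotting script.

```python
import numpy as np


def binary_entropy(p0):
    """Entropy (in nats) of a two-category variable with P(x0) = p0."""
    p0 = np.asarray(p0, dtype=float)
    return -(p0 * np.log(p0) + (1 - p0) * np.log(1 - p0))


# A few hypothetical datasets, each summarized by its p(x0).
points = np.array([0.1, 0.3, 0.5, 0.7, 0.9])
values = binary_entropy(points)

for p0, h in zip(points, values):
    print(f"p(x0) = {p0:.1f}  ->  entropy = {h:.4f}")
```

The curve is symmetric around $p(x_0) = 0.5$, where the entropy peaks at $\ln 2 \approx 0.693$ nats.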
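The cross entropy of $p(\cdot)$ and $q(\cdot)$ in the special case $p(\cdot) = q(\cdot)$ is
[Figure: cross entropy of distribution p() and p() over random var X; axes: p(x0), shadow of p(x0), entropy of X]
The cross entropy of $p(\cdot)$ and $q(\cdot)$ over every combination of $p(x_0)$ and $q(x_0)$ is
[Figure: cross entropy of distribution p() and q() over random var X; axes: p(x0), q(x0), cross entropy]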
As the figure shows, whatever the true distribution $p(\cdot)$ is, the cross entropy viewed as a function of the algorithm's distribution $q(\cdot)$ is smallest where $q(\cdot)$ equals $p(\cdot)$: the red markers, which are the entropy of $p(\cdot)$, trace the bottom of the surface along the $q$ direction. So as the cross entropy approaches its minimum, $q(\cdot)$ approaches $p(\cdot)$. If the algorithm produces the right distribution, it produces the right classification. That is why minimizing the cross entropy loss makes the algorithm produce the right classification.
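The claim can also be checked numerically rather than by eye. The sketch below reuses the grid from the plotting script and, for each fixed $p(x_0)$, looks up the $q(x_0)$ on the grid with the smallest cross entropy. The variable names `grid`, `ce`, `best_q`, `argmin_matches` and `min_matches` are assumptions introduced for this check.

```python
import numpy as np

# Same grid as the plotting script: 18 interior points of (0, 1).
grid = np.linspace(0, 1, 20)[1:-1]

# With the default 'xy' indexing, p varies along columns and q along rows.
p, q = np.meshgrid(grid, grid)
ce = -(p * np.log(q) + (1 - p) * np.log(1 - q))

# For each column (a fixed p), find the q that gives the smallest cross entropy.
best_q = grid[np.argmin(ce, axis=0)]

# Entropy of p, i.e. the cross entropy on the diagonal q = p.
entropy = -(grid * np.log(grid) + (1 - grid) * np.log(1 - grid))

argmin_matches = bool(np.array_equal(best_q, grid))
min_matches = bool(np.allclose(ce.min(axis=0), entropy))
print("argmin over q equals p:", argmin_matches)
print("minimum equals entropy of p:", min_matches)
```

On this grid both checks hold: the minimizing $q(x_0)$ coincides with $p(x_0)$, and the minimum value is the entropy of $p$.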

script of plotting

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 1/4/2019 10:04 PM
# @Author  : yusisc (yusisc@gmail.com)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection on older Matplotlib)
from matplotlib import cm

# Cross entropy - Wikipedia
# https://en.wikipedia.org/wiki/Cross_entropy
# 2D and 3D Axes in same Figure — Matplotlib 3.0.2 documentation
# https://matplotlib.org/gallery/mplot3d/mixed_subplots.html

fig = plt.figure(figsize=plt.figaspect(0.4))

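# Sample p(x0) on [0, 1], then drop the endpoints 0 and 1 to avoid log(0).
p0 = np.linspace(0, 1, 20)
p0 = p0[1:-1]
# Binary entropy of X in nats: -(p*log(p) + (1-p)*log(1-p)).
entropy = -(p0 * np.log(p0) +
            (1 - p0) * np.log(1 - p0))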

ax0 = fig.add_subplot(1, 3, 1)
ax0.plot(p0, entropy)
ax0.set_xlabel('p(x0)', color='r')
ax0.set_ylabel('entropy of X', color='r')
ax0.set_title('entropy of random var X')

ax1 = fig.add_subplot(1, 3, 2, projection='3d')
ax1.plot(p0, p0, entropy)
ax1.set_xlabel('p(x0)', color='r')
ax1.set_ylabel('shadow of p(x0)', color='r')
ax1.set_zlabel('entropy of X', color='r')
ax1.set_title('cross entropy of distribution p()\nand p() over random var X')

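# Every combination of p(x0) and q(x0): p varies along columns, q along rows.
p, q = np.meshgrid(p0, p0)
# Binary cross entropy H(p, q) in nats.
cross_entropy = -(p * np.log(q) +
                  (1 - p) * np.log(1 - q))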

ax2 = fig.add_subplot(1, 3, 3, projection='3d')
ax2.plot(p0, p0, entropy, 'r+')
ax2.plot_surface(p, q, cross_entropy, cmap=cm.coolwarm,
                 linewidth=0, antialiased=False, alpha=0.7)
ax2.set_xlabel('p(x0)', color='r')
ax2.set_ylabel('q(x0)', color='r')
ax2.set_zlabel('\n\ncross entropy \n of distribution p() and q() \n over the random variable X', color='r')
ax2.set_title('cross entropy of distribution p()\nand q() over random var X')

plt.show()

reference

Cross entropy - Wikipedia
https://en.wikipedia.org/wiki/Cross_entropy
2D and 3D Axes in same Figure — Matplotlib 3.0.2 documentation
https://matplotlib.org/gallery/mplot3d/mixed_subplots.html

posted on 2019-01-06 14:12  yusisc