one-hot 编码
def onehot(labels):
'''one-hot 编码'''
#数据有几行输出
n_sample = len(labels)
#数据分为几类。因为编码从0开始所以要加1
n_class = max(labels) + 1
#建立一个batch所需要的数组,全部赋0.
onehot_labels = np.zeros((n_sample, n_class))
#对每一行的,对应分类赋1
onehot_labels[np.arange(n_sample), labels] = 1
return onehot_labels
运行结果:
label=np.array([0,1,2])
onehot(label)
Out[8]:
array([[ 1., 0., 0.],
[ 0., 1., 0.],
[ 0., 0., 1.]])
label=np.array([0,4,7,1,1,1,4,1])
onehot(label)
Out[10]:
array([[ 1., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 1.],
[ 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0.]])
总结:本次标签只有一类,如第一个标签为一类,有两种情况。第二个为标签一类,有七种情况。如果标签为两类,比如{男生,女生}、{一年级、二年级、三年级},那编码的长度为5.
onehot标签则是顾名思义,一个长度为n的数组,只有一个元素是1.0,其他元素是0.0。
想想为什么要这样编码,知乎大佬的的一个解释感觉极有道理。
使用onehot的直接原因是现在多分类cnn网络的输出通常是softmax层,而它的输出是一个概率分布,从而要求输入的标签也以概率分布的形式出现,进而算交叉熵之类。