How to use CTCLoss
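What is CTC
CTC stands for Connectionist Temporal Classification.
CTCLoss is a loss function that measures the loss between the model output \(y\) and the label \(label\). Training a neural network is the process of driving this \(loss\) down.
CTCLoss is widely used in optical character recognition (OCR) and speech recognition. It does not require \(y\) and \(label\) to be aligned in time, which greatly reduces the work of producing aligned annotations.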
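Architecture overview
This article is mainly about CTCLoss. The model architecture is introduced here only to show the role CTCLoss plays in the overall pipeline. Suppose we have some raw data: an image containing text, or an audio clip of someone speaking.
As shown in the figure, the raw audio is converted by the DFT (discrete Fourier transform) into a feature map with time-frequency characteristics. The feature map is passed through the network \(\mathcal{N}_w\), which outputs \(y\) with \(y\in\mathbb{R}^{K \times T}\). The \(K\) dimension holds the predicted probability of each token (here, each letter or the blank) at a given time step, and \(T\) is the time dimension.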
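A simple example
Suppose we have an audio clip of someone spelling the English word "CAT", i.e. saying the three letters "C", "A" and "T". Saying the three letters takes 5 s, and we want to recognize them from the speech.
First we need an alphabet. Indices 1-26 represent the letters A-Z, and index 0 represents the blank, which covers the parts of the input that do not belong to any of the 26 letters. Next, assume the model outputs one probability distribution over this alphabet per second. The audio lasts 5 s, so there are 5 such distributions.
Table 1 below gives the probability distribution \(y\). Each column is the distribution for the input at that time step.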
\(y_t^k\) | t=1 | t=2 | t=3 | t=4 | t=5 |
---|---|---|---|---|---|
k=0 (-) | 0.031953 | 0.044296 | 0.038297 | 0.038320 | 0.027464 |
k=1 (A) | 0.026221 | 0.030363 | 0.031878 | 0.027295 | 0.029824 |
k=2 (B) | 0.040555 | 0.025838 | 0.023487 | 0.041529 | 0.028116 |
k=3 (C) | 0.029333 | 0.045889 | 0.031872 | 0.023184 | 0.029338 |
k=4 (D) | 0.023595 | 0.053792 | 0.022519 | 0.039882 | 0.025342 |
k=5 (E) | 0.048014 | 0.028887 | 0.020526 | 0.041302 | 0.045833 |
k=6 (F) | 0.028770 | 0.040735 | 0.045488 | 0.044244 | 0.032191 |
k=7 (G) | 0.035127 | 0.032281 | 0.034032 | 0.051973 | 0.041613 |
k=8 (H) | 0.044897 | 0.047910 | 0.049222 | 0.056956 | 0.048665 |
k=9 (I) | 0.032323 | 0.044911 | 0.038994 | 0.046017 | 0.040002 |
k=10 (J) | 0.047130 | 0.024608 | 0.034797 | 0.038146 | 0.041496 |
k=11 (K) | 0.033491 | 0.049294 | 0.043909 | 0.053962 | 0.037901 |
k=12 (L) | 0.044700 | 0.056019 | 0.046794 | 0.038094 | 0.027488 |
k=13 (M) | 0.045632 | 0.034822 | 0.052229 | 0.021692 | 0.039653 |
k=14 (N) | 0.035123 | 0.050406 | 0.019438 | 0.024067 | 0.056986 |
k=15 (O) | 0.023015 | 0.037482 | 0.046163 | 0.050536 | 0.058191 |
k=16 (P) | 0.031419 | 0.024302 | 0.035848 | 0.034614 | 0.031820 |
k=17 (Q) | 0.034497 | 0.025424 | 0.052284 | 0.049642 | 0.029912 |
k=18 (R) | 0.029572 | 0.031274 | 0.032931 | 0.026295 | 0.042725 |
k=19 (S) | 0.027484 | 0.044015 | 0.031383 | 0.037050 | 0.046068 |
k=20 (T) | 0.051330 | 0.047532 | 0.043297 | 0.040039 | 0.036849 |
k=21 (U) | 0.034691 | 0.045869 | 0.024400 | 0.022020 | 0.029838 |
k=22 (V) | 0.054835 | 0.028627 | 0.031971 | 0.039436 | 0.062661 |
k=23 (W) | 0.033373 | 0.035513 | 0.047827 | 0.030642 | 0.026361 |
k=24 (X) | 0.048700 | 0.022777 | 0.034515 | 0.022410 | 0.026991 |
k=25 (Y) | 0.033561 | 0.023278 | 0.045237 | 0.034797 | 0.027990 |
k=26 (Z) | 0.050657 | 0.023858 | 0.040665 | 0.025854 | 0.028682 |
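The example above describes the output \(y\) of the network \(\mathcal{N}_w\). The label \(label\) for this audio is the three letters 'C', 'A', 'T'. Written as alphabet indices, it is \(l=[3, 1, 20]\).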
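The conversion from letters to indices can be sketched as follows. It assumes the example's convention (blank = 0, A..Z = 1..26); the helper name `label_to_indices` is hypothetical.

```python
def label_to_indices(word: str) -> list[int]:
    """Map a word to alphabet indices using the example's convention: A=1, ..., Z=26 (0 is the blank)."""
    return [ord(ch) - ord("A") + 1 for ch in word.upper()]

print(label_to_indices("CAT"))  # [3, 1, 20]
```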
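Derivation of the CTC loss
In the paper, CTCLoss is defined as
\[O^{ML}(S,\mathcal{N}_w)=-\sum_{(x,z) \in S}\ln{p(z|x)} \]
What does this formula mean?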
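- \(S\) denotes a set of training samples, which is a subset of the overall distribution.
- \((x,z) \in S\): \(x\) is the output of the network \(\mathcal{N}_w\) for a raw input in the training set \(S\), and \(z\) is the label corresponding to \(x\).
- \(p(z|x)\) is the probability that the output \(x\) of \(\mathcal{N}_w\) decodes to the label \(z\), i.e. the conditional probability of \(z\) given \(x\).
- Multiplying \(p(z|x)\) over every sample in \(S\) gives the likelihood of \(S\) under \(\mathcal{N}_w\):
\[L(S,\mathcal{N}_w)=\prod_{(x,z) \in S}{p(z|x)} \]
- Training adjusts the parameters \(w\) of \(\mathcal{N}_w\) to maximize \(L(S,\mathcal{N}_w)\). This is maximum likelihood estimation.
- To make the computation easier, take \(\ln\) of both sides. This gives the log-likelihood, which turns the product into a sum:
\[\ln{L(S,\mathcal{N}_w)}=\sum_{(x,z) \in S}{\ln{p(z|x)}} \]
- A larger likelihood means a better fit, whereas a smaller loss means a better fit. The loss is therefore the negative log-likelihood:
\[O^{ML}(S,\mathcal{N}_w)=-\ln{L(S,\mathcal{N}_w)}=-\sum_{(x,z) \in S}\ln{p(z|x)} \]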
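To make the list concrete, the sketch below uses three made-up values of \(p(z|x)\) (hypothetical, not produced by any model). It checks that the negative log of the product equals the negative sum of the logs.

```python
import math

# Hypothetical per-sample probabilities p(z|x) for a 3-sample training set S
sample_probs = [0.5, 0.2, 0.8]

likelihood = math.prod(sample_probs)                      # L(S, N_w) = prod p(z|x)
log_likelihood = sum(math.log(p) for p in sample_probs)   # ln L = sum ln p(z|x)
loss = -log_likelihood                                    # O^ML = -ln L

assert math.isclose(log_likelihood, math.log(likelihood))
print(round(loss, 4))  # 2.5257, i.e. -ln(0.08)
```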
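Total probability \(p(z|x)\)
The key step in CTCLoss is computing the conditional probability \(p(z|x)\) for each sample \(\{x,z\} \in S\). Here \(z\) is the target label paired one-to-one with \(x\). By contrast, \(l\) denotes an arbitrary label: any sequence that follows the alphabet's rules, of which \(z\) is just one. During training we can simply set \(l=z\), but the derivation is more general when stated in terms of \(l\). So \(p(z|x)\) can be replaced by \(p(l|x)\), which is computed as
\[p(l|x)=\sum_{\pi \in \mathcal{B}^{-1}(l)}p(\pi|x)\]
where the paths \(\pi\), their probabilities \(p(\pi|x)\) and the mapping \(\mathcal{B}\) are explained below.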
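What a path is
The network \(\mathcal{N}_w\) outputs \(x\in\mathbb{R}^{K \times T}\). There are \(T\) time steps with \(K\) possible outputs at each one, so there are \(K^T\) paths in total. In the example above, \(K=27\) and \(T=5\), giving \(27^5=14348907\) possible paths. Even with only \(T=5\), the number of paths is already very large.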
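The path count is plain arithmetic. Here is a quick check, with a hypothetical \(T=100\) added to show how fast \(K^T\) grows.

```python
K, T = 27, 5                # alphabet size (26 letters + blank) and number of time steps
num_paths = K ** T
print(num_paths)            # 14348907

# Hypothetical longer input: with T = 100 time steps the count already has 144 digits
print(len(str(K ** 100)))   # 144
```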
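Path probability \(p(\pi|x)\)
Table 1 gives the probability of every letter at each time step. Choosing one symbol per time step forms a path \(\pi\), and the path's probability is the product of the probabilities of the chosen symbols:
\[p(\pi|x)=\prod_{t=1}^{T}y_t^{\pi_t}\]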
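What is the \(\mathcal{B}\) transform
For the \(27^5\) paths above, the \(\mathcal{B}\) transform first merges adjacent repeated symbols and then removes every blank \((-)\). For example, \(\mathcal{B}(\text{CCA-T})=\mathcal{B}(\text{C-AAT})=\text{CAT}\). In contrast, \(\mathcal{B}(\text{C-CAT})=\text{CCAT}\), because the blank separates the two C's.
Likewise, \(\mathcal{B}^{-1}(l)\) denotes the inverse of \(\mathcal{B}\): the set of all paths \(\pi\) that satisfy \(\mathcal{B}(\pi)=l\).
So \(p(l|x)\) is not the sum of the probabilities of all paths. It is the sum over only those paths that satisfy \(\mathcal{B}(\pi)=l\).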
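The sketch below implements \(\mathcal{B}\) in Python, assuming paths are written as strings with `-` as the blank; the function name `ctc_collapse` is illustrative. It merges runs first and removes blanks second, which is why a blank between two identical letters keeps both.

```python
from itertools import groupby

BLANK = "-"

def ctc_collapse(path: str, blank: str = BLANK) -> str:
    """The B mapping: merge runs of repeated symbols, then drop blanks."""
    merged = [symbol for symbol, _ in groupby(path)]   # "CCA-T" -> ["C", "A", "-", "T"]
    return "".join(s for s in merged if s != blank)

print(ctc_collapse("--CAT"))  # CAT
print(ctc_collapse("CCA-T"))  # CAT
print(ctc_collapse("C-CAT"))  # CCAT: the blank keeps the two C's apart
print(ctc_collapse("CCCAT"))  # CAT
```

Swapping the two steps would turn `C-CAT` into `CAT`. That result is wrong, because the blank is exactly what allows CTC to output double letters.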
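Computing CTCLoss by hand, step by step
Using the example above, we now compute CTCLoss by hand, one step at a time.
Find all paths satisfying \(\mathcal{B}(\pi)=l\), with \(l\)="CAT"
Of the \(27^5\) paths above, 28 satisfy \(\mathcal{B}(\pi)=l\) with \(l\)="CAT", as shown in Table 2.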
\(\pi\) | t=1 | t=2 | t=3 | t=4 | t=5 |
---|---|---|---|---|---|
\(\pi_{1}\) | - | - | C | A | T |
\(\pi_{2}\) | - | C | - | A | T |
\(\pi_{3}\) | - | C | C | A | T |
\(\pi_{4}\) | - | C | A | - | T |
\(\pi_{5}\) | - | C | A | A | T |
\(\pi_{6}\) | - | C | A | T | - |
\(\pi_{7}\) | - | C | A | T | T |
\(\pi_{8}\) | C | - | - | A | T |
\(\pi_{9}\) | C | - | A | - | T |
\(\pi_{10}\) | C | - | A | A | T |
\(\pi_{11}\) | C | - | A | T | - |
\(\pi_{12}\) | C | - | A | T | T |
\(\pi_{13}\) | C | C | - | A | T |
\(\pi_{14}\) | C | C | C | A | T |
\(\pi_{15}\) | C | C | A | - | T |
\(\pi_{16}\) | C | C | A | A | T |
\(\pi_{17}\) | C | C | A | T | - |
\(\pi_{18}\) | C | C | A | T | T |
\(\pi_{19}\) | C | A | - | - | T |
\(\pi_{20}\) | C | A | - | T | - |
\(\pi_{21}\) | C | A | - | T | T |
\(\pi_{22}\) | C | A | A | - | T |
\(\pi_{23}\) | C | A | A | A | T |
\(\pi_{24}\) | C | A | A | T | - |
\(\pi_{25}\) | C | A | A | T | T |
\(\pi_{26}\) | C | A | T | - | - |
\(\pi_{27}\) | C | A | T | T | - |
\(\pi_{28}\) | C | A | T | T | T |
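Table 2 can be reproduced by brute force. The sketch enumerates only paths over {C, A, T, -}, because any other letter would survive \(\mathcal{B}\) and break the match. As an aside, the closed-form count \(\binom{T+U}{2U}\) (with \(U\) the label length) follows from a stars-and-bars argument. It holds only for labels with no adjacent repeated letters.

```python
from itertools import groupby, product
from math import comb

def ctc_collapse(path: str, blank: str = "-") -> str:
    """B mapping: merge runs of repeated symbols, then drop blanks."""
    return "".join(s for s, _ in groupby(path) if s != blank)

def alignments(label: str, T: int, blank: str = "-") -> list[str]:
    """Brute-force B^{-1}(label) for length-T paths over the reduced alphabet."""
    symbols = sorted(set(label)) + [blank]
    candidates = ("".join(p) for p in product(symbols, repeat=T))
    return [path for path in candidates if ctc_collapse(path, blank) == label]

paths = alignments("CAT", 5)
print(len(paths))          # 28
print(comb(5 + 3, 2 * 3))  # 28, the stars-and-bars count C(T+U, 2U)
```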
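Computing each path probability \(p(\pi|x)\)
Path \(\pi_1\) corresponds to the sequence "- - C A T". Converted to alphabet indices, this is \([0, 0, 3, 1, 20]\).
Reading these indices from Table 1, path \(\pi_1\) takes the following values at each time step: \(y_1^{0}=0.031953\), \(y_2^{0}=0.044296\), \(y_3^{3}=0.031872\), \(y_4^{1}=0.027295\), \(y_5^{20}=0.036849\).
The probability of path \(\pi_1\) is therefore
\[p(\pi_1|x)=y_1^{0}\,y_2^{0}\,y_3^{3}\,y_4^{1}\,y_5^{20}=0.031953\times0.044296\times0.031872\times0.027295\times0.036849\approx4.537\times10^{-8}\]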
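The product for \(\pi_1\) can be checked with a few lines of Python. The sketch hard-codes only the four rows of Table 1 that a "CAT" path can use. Those values are rounded to 6 decimals, so the result is approximate.

```python
import math

# Rows of Table 1 for blank, C, A, T; list index 0 corresponds to t=1
y = {
    "-": [0.031953, 0.044296, 0.038297, 0.038320, 0.027464],
    "C": [0.029333, 0.045889, 0.031872, 0.023184, 0.029338],
    "A": [0.026221, 0.030363, 0.031878, 0.027295, 0.029824],
    "T": [0.051330, 0.047532, 0.043297, 0.040039, 0.036849],
}

def path_prob(path: str) -> float:
    """p(pi|x) = product over t of y_t^{pi_t}."""
    return math.prod(y[symbol][t] for t, symbol in enumerate(path))

p_pi1 = path_prob("--CAT")
print(f"{p_pi1:.4e}")  # 4.5373e-08
```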
The probabilities \(p(\pi_2|x),\dots,p(\pi_{28}|x)\) of the remaining paths are computed the same way.
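Computing the total probability \(p(l|x)\)
\(p(l|x)\) is the sum of the probabilities of all paths that satisfy \(\mathcal{B}(\pi)=l\):
\[p(l|x)=\sum_{\pi \in \mathcal{B}^{-1}(l)}p(\pi|x)=\sum_{i=1}^{28}p(\pi_i|x)\approx1.366\times10^{-6}\]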
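Computing the loss function CTCLoss
The example contains only one sample, so the CTCLoss here is simply the loss of that single sample:
\[O^{ML}=-\ln{p(l|x)}\approx-\ln{(1.366\times10^{-6})}\approx13.5036\]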
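The whole hand computation can be put together as a brute-force sketch. It enumerates \(\mathcal{B}^{-1}(\text{CAT})\), sums the 28 path probabilities, and takes \(-\ln\). Because it uses the 6-decimal values from Table 1, it should agree with the library result to about four decimal places. This is only an illustration: CTC libraries use a dynamic-programming forward pass rather than enumerating paths.

```python
import math
from itertools import groupby, product

# Rows of Table 1 for the only symbols a "CAT" path can contain; index 0 is t=1
y = {
    "-": [0.031953, 0.044296, 0.038297, 0.038320, 0.027464],
    "C": [0.029333, 0.045889, 0.031872, 0.023184, 0.029338],
    "A": [0.026221, 0.030363, 0.031878, 0.027295, 0.029824],
    "T": [0.051330, 0.047532, 0.043297, 0.040039, 0.036849],
}

def ctc_collapse(path, blank="-"):
    """B mapping: merge runs of repeated symbols, then drop blanks."""
    return "".join(s for s, _ in groupby(path) if s != blank)

def path_prob(path):
    """p(pi|x): product of y_t^{pi_t} over the time steps."""
    return math.prod(y[s][t] for t, s in enumerate(path))

label, T = "CAT", 5
candidates = ("".join(p) for p in product(y, repeat=T))
valid_paths = [p for p in candidates if ctc_collapse(p) == label]  # B^{-1}(l)

p_l = sum(path_prob(p) for p in valid_paths)  # p(l|x)
loss = -math.log(p_l)                         # CTCLoss for this single sample
# Expected: 28 paths and a loss of about 13.5036
print(len(valid_paths), f"{p_l:.3e}", f"{loss:.4f}")
```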
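Verifying with library CTCLoss functions
Softmax of the network output \(y\_out\) of \(\mathcal{N}_w\)
This section needs one point of explanation: how the CTCLoss input \(ctc\_input\) relates to the network output \(y\_out\).
The last layer of \(\mathcal{N}_w\) has no softmax, so at each time step the values of \(y\_out\) do not sum to 1. To normalize them into a probability distribution, \(y\_out\) is passed through softmax, giving \(y\_softmax\).
In addition, CTCLoss multiplies a large number of probabilities, so \(\ln\) is applied to \(y\_softmax\). This turns products into sums and prevents long products of small probabilities from underflowing to zero.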
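The underflow argument can be demonstrated with a hypothetical sequence of 300 time steps, each contributing probability 0.03.

```python
import math

# Hypothetical: 300 time steps, each contributing a probability of 0.03
probs = [0.03] * 300

direct = math.prod(probs)                    # 0.03**300 is about 1e-457: underflows float64 to 0.0
log_space = sum(math.log(p) for p in probs)  # stays finite: 300 * ln(0.03)

print(direct)               # 0.0
print(round(log_space, 2))  # -1051.97
```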
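To keep the walkthrough above intuitive, \(y\) was already taken to be the softmax output, i.e. \(y=\mathrm{softmax}(y\_out)\) computed over the \(K\) dimension at each time step.
The table below shows \(y\_out\). Its columns clearly do not sum to 1.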
\(y\_out_t^k\) | t=1 | t=2 | t=3 | t=4 | t=5 |
---|---|---|---|---|---|
k=0 (-) | 0.347713 | 0.755077 | 0.678652 | 0.585987 | 0.123084 |
k=1 (A) | 0.149997 | 0.377396 | 0.495177 | 0.246735 | 0.205494 |
k=2 (B) | 0.586092 | 0.216019 | 0.189710 | 0.666416 | 0.146515 |
k=3 (C) | 0.262145 | 0.790407 | 0.495006 | 0.083483 | 0.189072 |
k=4 (D) | 0.044454 | 0.949304 | 0.147608 | 0.625960 | 0.042652 |
k=5 (E) | 0.754933 | 0.327565 | 0.054974 | 0.660945 | 0.635198 |
k=6 (F) | 0.242785 | 0.671264 | 0.850713 | 0.729752 | 0.281867 |
k=7 (G) | 0.442402 | 0.438645 | 0.560560 | 0.890752 | 0.538597 |
k=8 (H) | 0.687796 | 0.833501 | 0.929609 | 0.982303 | 0.695163 |
k=9 (I) | 0.359228 | 0.768854 | 0.696667 | 0.769029 | 0.499116 |
k=10 (J) | 0.736340 | 0.167254 | 0.582791 | 0.581446 | 0.535801 |
k=11 (K) | 0.394707 | 0.861980 | 0.815397 | 0.928313 | 0.445183 |
k=12 (L) | 0.683416 | 0.989872 | 0.879014 | 0.580090 | 0.123932 |
k=13 (M) | 0.704047 | 0.514423 | 0.988912 | 0.016983 | 0.490357 |
k=14 (N) | 0.442305 | 0.884281 | 0.000522 | 0.120860 | 0.852998 |
k=15 (O) | 0.019578 | 0.588026 | 0.865439 | 0.862711 | 0.873927 |
k=16 (P) | 0.330858 | 0.154752 | 0.612566 | 0.484297 | 0.270294 |
k=17 (Q) | 0.424309 | 0.199863 | 0.989950 | 0.844856 | 0.208461 |
k=18 (R) | 0.270270 | 0.406955 | 0.527680 | 0.209405 | 0.564980 |
k=19 (S) | 0.197054 | 0.748706 | 0.479523 | 0.552291 | 0.640312 |
k=20 (T) | 0.821721 | 0.825584 | 0.801348 | 0.629883 | 0.417029 |
k=21 (U) | 0.429921 | 0.789963 | 0.227843 | 0.031991 | 0.205976 |
k=22 (V) | 0.887771 | 0.318524 | 0.498094 | 0.614713 | 0.947933 |
k=23 (W) | 0.391183 | 0.534064 | 0.900852 | 0.362411 | 0.082071 |
k=24 (X) | 0.769114 | 0.089951 | 0.574661 | 0.049533 | 0.105709 |
k=25 (Y) | 0.396792 | 0.111706 | 0.845178 | 0.489570 | 0.142041 |
k=26 (Z) | 0.808514 | 0.136293 | 0.738640 | 0.192510 | 0.166460 |
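As a sanity check on the relationship \(y=\mathrm{softmax}(y\_out)\), the sketch below applies softmax to the \(t=1\) column of the table above. It then compares two entries with Table 1. The `softmax` helper is written here for illustration in plain Python.

```python
import math

# Column t=1 of y_out (k = 0..26), copied from the table above
y_out_t1 = [0.347713, 0.149997, 0.586092, 0.262145, 0.044454, 0.754933, 0.242785,
            0.442402, 0.687796, 0.359228, 0.736340, 0.394707, 0.683416, 0.704047,
            0.442305, 0.019578, 0.330858, 0.424309, 0.270270, 0.197054, 0.821721,
            0.429921, 0.887771, 0.391183, 0.769114, 0.396792, 0.808514]

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

y_t1 = softmax(y_out_t1)
print(round(sum(y_t1), 6))  # 1.0
# Table 1 lists 0.031953 for the blank (k=0) and 0.051330 for T (k=20) at t=1
print(round(y_t1[0], 6), round(y_t1[20], 6))
```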
Verifying with the PyTorch library function
For details on using CTCLoss, see the official PyTorch documentation.
import torch
import torch.nn as nn
import numpy as np
y_softmax = np.array([
[[0.0319533345695271, 0.0262210133693412, 0.0405548727460100, 0.0293328834922530, 0.0235946021815836, 0.0480142162870594, 0.0287704618407728, 0.0351268637054168, 0.0448965052477630, 0.0323234212279283, 0.0471297269219778, 0.0334908192070999, 0.0447002788315031, 0.0456320948241136,
0.0351234600906292, 0.0230148922614546, 0.0314192811142228, 0.0344970346892286, 0.0295721871384341, 0.0274843752526059, 0.0513304969210734, 0.0346911732659917, 0.0548353372646645, 0.0333729892573427, 0.0486999624899632, 0.0335606882517763, 0.0506570275502634]],
[[0.0442961938109001, 0.0303627704208565, 0.0258378526020265, 0.0458891577161975, 0.0537920435977104, 0.0288868677848477, 0.0407349328912650, 0.0322806067098565, 0.0479099042067772, 0.0449106925711146, 0.0246080887866719, 0.0492939884049119, 0.0560191619281624, 0.0348218517081914,
0.0504056201105211, 0.0374815087428365, 0.0243023731122621, 0.0254237678526359, 0.0312736688595233, 0.0440148630768450, 0.0475321094768427, 0.0458687788283468, 0.0286268732637606, 0.0355125367928648, 0.0227774801386588, 0.0232784351056503, 0.0238578714997625]],
[[0.0382974368377362, 0.0318777312135849, 0.0234868589674224, 0.0318722744011979, 0.0225185381516373, 0.0205262552943881, 0.0454877627911883, 0.0340316234294017, 0.0492219436202117, 0.0389936131137926, 0.0347966678592871, 0.0439093761642613, 0.0467935124498177, 0.0522292290638150,
0.0194384495697102, 0.0461625681675025, 0.0358483354617907, 0.0522835019782284, 0.0329308772273817, 0.0313826141807340, 0.0432967801742709, 0.0243997674509821, 0.0319708630090250, 0.0478266566415420, 0.0345149265806327, 0.0452367066323343, 0.0406651295681235]],
[[0.0383195501689954, 0.0272951137973125, 0.0415288927451887, 0.0231838517718695, 0.0398823138441169, 0.0413022813256117, 0.0442442329310963, 0.0519730489462436, 0.0569558497142297, 0.0460166028890008, 0.0381459528257684, 0.0539623316283564, 0.0380942573161036, 0.0216922730554261,
0.0240667868142706, 0.0505358960731075, 0.0346143968556499, 0.0496415831760055, 0.0262949856792733, 0.0370498580320465, 0.0400391034751884, 0.0220202876462848, 0.0394362954874324, 0.0306423990773223, 0.0224099657044701, 0.0347974172676594, 0.0258544717519696]],
[[0.0274643882294982, 0.0298236175649923, 0.0281155092606543, 0.0293378537462717, 0.0253418924737544, 0.0458330002578632, 0.0321905618820226, 0.0416126048467898, 0.0486654573566434, 0.0400017201897758, 0.0414964341812715, 0.0379014590893513, 0.0274877024782956, 0.0396528862221281,
0.0569859416555112, 0.0581911831104043, 0.0318201830875284, 0.0299122412570334, 0.0427250763149338, 0.0460679863549903, 0.0368492548068844, 0.0298379764585031, 0.0626610201269008, 0.0263607892806820, 0.0269913345266294, 0.0279900073565483, 0.0286819178841385]]
]).astype("float32")
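labels = np.array([[3, 1, 20]]).astype("int32")  # target "CAT" as alphabet indices, shape (N, S) = (1, 3)
input_lengths = np.array([5]).astype("int64")    # T = 5 time steps for the single sample
label_lengths = np.array([3]).astype("int64")    # the label has 3 symbols
# nn.CTCLoss expects log-probabilities of shape (T, N, C) = (5, 1, 27)
ctc_input = torch.tensor(y_softmax).log()
labels = torch.tensor(labels)
input_lengths = torch.tensor(input_lengths)
label_lengths = torch.tensor(label_lengths)
# blank defaults to index 0; reduction='none' returns the per-sample loss
ctc_loss = nn.CTCLoss(reduction='none')
loss = ctc_loss(ctc_input, labels, input_lengths, label_lengths)
print('loss is {}'.format(loss))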
loss is tensor([13.5036])
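Using the Paddle library function
For details on using CTCLoss, see the official Paddle documentation.
Paddle's underlying CTCLoss implementation already applies log_softmax, so its input can be \(y\_out\) directly.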
import numpy as np
import paddle
import paddle.nn.functional as F
y_out = np.array([
[[0.347712671277525, 0.149997253831683, 0.586092067231462, 0.262145317727807, 0.0444540922782385, 0.754933267231179, 0.242785357820962, 0.442402313001943, 0.687796085120107, 0.359228210401861, 0.736340074301202, 0.394707475278763, 0.683415866967978, 0.704047430334266,
0.442305413383371, 0.0195776235533187, 0.330857880214071, 0.424309496833137, 0.270270423432065, 0.197053798095456, 0.821721184961310, 0.429921409383266, 0.887770954256354, 0.391182995461163, 0.769114387388296, 0.396791517013617, 0.808514095887345]],
[[0.755077099007084, 0.377395544835103, 0.216018915961394, 0.790407217966913, 0.949303911849797, 0.327565434075205, 0.671264370451740, 0.438644982586956, 0.833500595588975, 0.768854252429615, 0.167253545494722, 0.861980478702072, 0.989872153631504, 0.514423456505704,
0.884281023126955, 0.588026055308498, 0.154752348656045, 0.199862822857452, 0.406954837138907, 0.748705718215691, 0.825583815786156, 0.789963029944531, 0.318524245398992, 0.534064127370726, 0.0899506787705811, 0.111705744193203, 0.136292548938299]],
[[0.678652304800188, 0.495177019089661, 0.189710406017580, 0.495005824990221, 0.147608221976689, 0.0549741469061882, 0.850712674289007, 0.560559527354885, 0.929608866756663, 0.696667200555228, 0.582790965175840, 0.815397211477421, 0.879013904597178, 0.988911616079589,
0.000522375356944771, 0.865438591013025, 0.612566469483999, 0.989950205708831, 0.527680069338442, 0.479523385210219, 0.801347605521952, 0.227842935706042, 0.498094291196390, 0.900852488532005, 0.574661219130188, 0.845178185054037, 0.738640291995402]],
[[0.585987035826476, 0.246734525985975, 0.666416217319468, 0.0834828136026227, 0.625959785171583, 0.660944557947342, 0.729751855317221, 0.890752116325322, 0.982303222883606, 0.769029085335896, 0.581446487875398, 0.928313062314188, 0.580090365758442, 0.0169829383372613,
0.120859571098558, 0.862710718699670, 0.484296511212103, 0.844855674576263, 0.209405084020935, 0.552291341538775, 0.629883385064421, 0.0319910157625669, 0.614713419117141, 0.362411462273053, 0.0495325790420612, 0.489569989177322, 0.192510396062075]],
[[0.123083747545945, 0.205494170907680, 0.146514910614890, 0.189072174472614, 0.0426524109111434, 0.635197916859882, 0.281866855880430, 0.538596678045340, 0.695163039444332, 0.499116013482590, 0.535801055751113, 0.445183165296042, 0.123932277598070, 0.490357293468018,
0.852998155340816, 0.873927405861733, 0.270294332292698, 0.208461358751314, 0.564979570738201, 0.640311825162758, 0.417028951642886, 0.205975515532243, 0.947933121293169, 0.0820712070977259, 0.105709426581721, 0.142041121903998, 0.166460440876421]]
]).astype("float32")
labels = np.array([[3, 1, 20]]).astype("int32")  # target "CAT" as alphabet indices
input_lengths = np.array([5]).astype("int64")    # T = 5 time steps
label_lengths = np.array([3]).astype("int64")    # the label has 3 symbols
# raw (unnormalized) outputs of shape [T, N, num_classes + 1] = [5, 1, 27]
y_out = paddle.to_tensor(y_out)
labels = paddle.to_tensor(labels)
input_lengths = paddle.to_tensor(input_lengths)
label_lengths = paddle.to_tensor(label_lengths)
loss = paddle.nn.CTCLoss(blank=0, reduction='none')(y_out, labels,
                                                    input_lengths,
                                                    label_lengths)
print('loss is {}'.format(loss))
loss is Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[13.50364304])