ctc loss

原文地址:

https://zhuanlan.zhihu.com/p/23309693

https://zhuanlan.zhihu.com/p/23293860

 

CTC:前向计算例子

这里我们直接使用warp-ctc中的变量进行分析。我们定义T为RNN输出的结果的维数,这个问题的最终输出维度为alphabet_size。而ground_truth的维数为L。也就是说,RNN输出的结果为alphabet_size*T的结果,我们要将这个结果和1*L这个向量进行对比,求出最终的Loss。

我们要一步一步地揭开这个算法的细节……当然这个算法的实现代码有点晦涩……

我们的第一步要顺着test_cpu.cpp的路线来分析代码。第一步我们就是要解析small_test()中的内容。也就是做前向计算,计算对于RNN结果来说,对应最终的ground_truth——t的label的概率。

这个计算过程可以用动态规划的算法求解。我们可以用一个变量来表示动态规划的中间过程,它就是:

\alpha^T_i:表示在RNN计算的时间T时刻,这一时刻对应的ground_truth的label为第i个下标的值t[i]的概率。

这样的表示有点抽象,我们用一个实际的例子来讲解:

RNN结果:[R_1,R_2,R_3,R_4],这里的每一个变量都对应一个列向量。

ground_truth:[g_1,g_2,g_3]

那么\alpha^2_1表示R_2的结果对应着g_1的概率,当然与此同时,前面的结果也都合理地对应完成。

从上面的结果我们可以看出,如果R_2的结果对应着g_1,那么R_1的结果也必然对应着g_1。所以前面的结果是确定的。然而对于其他的一些情况来说,我们的转换存在着一定的不确定性。

CTC:前向计算具体过程

我们还是按照上面的例子进行计算,我们把刚才的例子搬过来:

RNN结果:[R_1,R_2,R_3,R_4],这里的每一个变量都对应一个列向量。

ground_truth:[g_1,g_2,g_3]

alphabet:[g_0(blank),g_1,g_2,g_3]

按照上面介绍的计算方法,第一步我们先做ground_truth的状态扩展,于是我们就把长度从3扩展到了7,现在的ground_truth变成了:

[blank,g_1,blank,g_2,blank,g_3,blank]

我们的RNN结果长度为4,也就是说我们会从上面的7个ground_truth状态中进行转移,并最终转移到最终状态。理论上利用动态规划的算法,我们需要计算4*7=28个中间结果。好了,下面我们用P^T_i表示RNN的第T时刻状态为ground_truth中是第i个位置的概率。

那么我们就开始计算了:

T=1时,我们只能选择g_1和blank,所以这一轮我们终结状态只可能落在0和1上。所以第一轮变成了:

[P^1_0,P^1_1,0,0,0,0,0]

T=2时,我们可以继续选择g_1,我们同时也可以选择g_2,还可以选择g_1g_2之间的blank,所以我们可以进一步关注这三个位置的概率,于是我们将其他的位置的概率设为0。[0,(P^1_0 +P^1_1)P^2_1,P^1_1P^2_2,P^1_1P^2_3,0,0,0]

T=3时,留给我们的时间已经不多了,我们还剩2步,要走完整个旅程,我们只能选择g_2g_3以及它们之间的空格。于是乎我们关心的位置又发生了变化:

[0,0,0,
(P^1_1P^2_2+P^1_1P^2_3)P^3_3,
P^1_1P^2_3P^3_4,
P^1_1P^2_3P^3_5,
0]

是不是有点看晕了?没关系,因为还剩最后一步了。下面是最后一步,因为最后一步我们必须要到g_3以及它后面的空格了,所以我们的概率最终计算也就变成了:

[0,0,0,
0,0,
((P^1_1P^2_2+P^1_1P^2_3)P^3_3+P^1_1P^2_2P^3_4+P^1_1P^2_2P^3_3)P^4_5,
P^1_1P^2_3P^3_5P^4_6]

好吧,最终的结果我们求出来了,实际上这就是通过时间的推移不断迭代求解出来的。关于迭代求解的公式这里就不再赘述了。我们直接来看一张图:

于是乎我们从这个计算过程中发现一些问题:

首先是一个相对简单的问题,我们看到在计算过程中我们发现了大量的连乘。由于每一个数字都是浮点数,那么这样连乘下去,最终数字有可能非常小而导致underflow。所以我们要将这个计算过程转到对数域上。这样我们就将其中的乘法转变成了加法。但是原本就是加法的计算呢?比方说我们现在计算了loga和logb,我们如何计算log(a+b)呢,这里老司机给出了解决方案,我们假设两个数中a>b,那么有

log(a+b)=log(a(1+\frac{b}{a}))=loga+log(1+\frac{b}{a})
=loga+log(1+exp(log(\frac{b}{a})))=loga+log(1+exp(logb - loga))

这样我们就利用了loga和logb计算出了log(a+b)来。

另外一个问题就是,我们发现在刚才的计算过程当中,对于每一个时间段,我们实际上并不需要计算每一个ground-truth位置的概率信息,实际上只要计算满足某个条件的某一部分就可以了。所以我们有没有希望在计算前就规划好这条路经,以保证我们只计算最相关的那些值呢?

如何控制计算的数量?

不得不说,这一部分warp-ctc写得实在有点晦涩,当然也可能是我在这方面的理解比较渣。我们这里主要关注两个部分——一个是数据的准备,一个是最终的数据的使用。

在介绍数据准备之前,我们先简单说一下这部分计算的大概思路。我们用两个变量start和end表示我们需要计算的状态的起止点,在每一个时间点,我们要更新start和end这两个变量。然后我们更新start和end之间的概率信息。这里我们先要考虑一个问题,start和end的更新有什么规律?

为了简化思考,我们先假设ground_truth中没有重复的label,我们的大脑瞬间得到了解放。好了,下面我们就要给出代码中的两个变量——

T:表示RNN结果中的维度

S/2:ground_truth的维度(S表示了扩展blank之后的维度)

基本上具备一点常识,我们就可以知道T>=S/2。什么?你觉得有可能出现T<S/2的情况?兄弟,这种见鬼的事情如果发生,你难道要我们把RNN的结果拆开给你用?臣妾不太能做得到啊……

好了,既然接受了上面的事实,那么我们就来举几个例子看看:

我们假设T=3,S/2=3,那么说白了,它们之间的对应关系是一一对应,说白了这就和blank位置没啥关系了。在T=1时,我们要转移到第一个结果,T=2,我们要转移到第二个结果……

 

如何控制计算的数量?cont.

好,废话少说我们书接上回。不明真相的小朋友先看这个:

下面我们假设T=4,S/2=3,好玩的地方来了。T比S/2多一个,也就是说我们允许冗余出现了,那么我们可能的形式也就变多了。我们可以增加一个blank,我们也可以在没有label位置原地打一轮酱油。选择更多,欢乐更多。

虽然选择变多,但是着并不意味着我们可以选择任意一种状态转移的方式,至少:

  • 在T=2时,我们至少要转移到第一个结果
  • 在T=3时,我们至少要转移到第二个结果
  • 在T=4时,兄弟我们准备下车了

这其实就是对start的限制。源代码中有这样一句话:

int remain = (S / 2) + repeats - (T - t);

这里我们先忽略repeats,那么remain这个变量其实是在计算label数量和剩余时间的差。如果用这样的语言来表达刚才的那个问题,我们语言就变成这个样子:

  • 当时间还剩4轮时(包括第4轮),我们在哪都无所谓(实际上是从T=1开始计算的)
  • 当时间还剩3轮时(包括第3轮),我们至少要转移到第一个结果(index=1)
  • 当时间还剩2轮时(包括第2轮),我们至少要转移到第二个结果(index=3)
  • 当时间还剩1轮时(包括第1轮),我们至少要转移到第三个结果(index=5)

好了,这里我们看出其中的含义了。我们再啰嗦一下,看看这些变量随T的变化情况:

  • T=1,remain=0,start+=1
  • T=2,remain=1,start+=2
  • T=3,remain=2,start+=2

现在我们已经十分清楚了,当remain>=0时,start都要向前走,限制我们计算前面状态的概率,因为这些概率已经没有意义了。下面的代码也是这样描述的:

if(remain >= 0)
    start += s_inc[remain];

那么这个s_inc是什么东西?它就是我们需要提前准备好的计算量。我们知道经过扩充的label序列中,所有的非空label都处在奇数的index上,而填充的blank都处在偶数的index上(我们是0-based的计算方法,matlab选手请退散……),所以对于上面的问题,当start=0时,下一步我们会从0跳到1,此后我们会从1到3,3到5,跳转的步数都是2,所以基于这个思路,我们就可以把s_inc这个数组生成出来。当然,我们的前提是没有重复。下面我们会说重复的问题的。

我们上面说了这么多,重点把start的变化介绍清楚了。下面我们来看看end。其实end的原理也类似,我们还是用刚才的废话套路来介绍站在end视角的世界:

  • 在T=1时,我们最多能到第一个结果
  • 在T=2时,我们最多能转移到第二个结果
  • 在T=3时,我们最多能转移到第三个结果
  • 在T=4时,我们已经掌握了整个世界……oh yeah

好了,可以看出end的变化形式,每个时刻end都可以+2,直到到达最后一个非blank的label,end变成了+1,然后end就不用动了,等着start动就可以了……(怎么感觉有点污?天哪……)

那么end变化的条件是什么呢?

if(t <= (S / 2) + repeats)
    end += e_inc[t - 1];

我们还是忽略repeats,那么就十分清楚了,如果当前时刻小于等于label数,那么尽管前进,如果大于了,基本上也就到头了,这时候end就不用动了。

好了,前面我们终于说完了简单模式下start和end的移动规律,下面我们来看看带重复模式下的变化方法。

重复,重复

重复会带来什么样的变化呢?说白了如果有重复的label出现,那么两个连续重复的label中间就要至少出现一个blank。换句话说,每出现一个重复,我们的S/2就要加一,于是我们再看一眼这两个计算公式:

int remain = (S / 2) + repeats - (T - t);
if(remain >= 0)
    start += s_inc[remain];
if(t <= (S / 2) + repeats)
    end += e_inc[t - 1];

我们把repeats和S/2归到一起,这时候就能看明白了。

同理,在计算s_inc和e_inc的时候,由于有repeats的存在,它们从过去的+2变成了两个+1。也就是说先从label跳到blank,再跳到下一个label。这样就可以解释s_inc和e_inc的初始化策略了:

int e_counter = 0;
int s_counter = 0;

s_inc[s_counter++] = 1;

int repeats = 0;

for (int i = 1; i < L; ++i) {
    if (labels[i-1] == labels[i]) {
        s_inc[s_counter++] = 1;
        s_inc[s_counter++] = 1;
        e_inc[e_counter++] = 1;
        e_inc[e_counter++] = 1;
        ++repeats;
    }
    else {
        s_inc[s_counter++] = 2;
        e_inc[e_counter++] = 2;
    }
}
e_inc[e_counter++] = 1;

好了,到此我们才算把CTC中compute ctc loss这部分介绍完了。教科书上的一个公式看着简单,落实到代码就似乎充满了trick。希望看懂了这个计算的你大脑没有阵亡。

posted on 2017-08-29 17:50  azheng333  阅读(5533)  评论(0编辑  收藏  举报