DL学习-ctc解码
参考基于CTC的序列模型:https://distill.pub/2017/ctc/
ctc解码方式:
- Greedy decode,每次都选取概率最大。
- Beam Search,对规整字符串进行束搜索算法。
- FST Status Encode
对齐方式:
方案1:为每个输入步骤分配一个输出字符,堆叠重复的字符。
方案2:为字符添加blank用于防止hello被误解码为helo,防止output中的重复字符被折叠。
特点:1.对齐具有单调性;2.X->Y的映射是多个\(x_i\)对应一个\(y_i\);3.Y的长度不能比X长。
流程:
为了方便就直接截图了,后面其实自己结合自己理解画了一个,写完附在最后面吧。
简单的beam-search算法:
输入:
- 模型的输出y,shapes:[sequence_length,num_outsize]
- 束的长度beam_size,表示搜索的宽度
输出:
- 最优的beam_size条路径。
算法思路:
1.对y求log。
2.初始化beam为[([],0)]
3.迭代sequence中的每一个下标,根据log y的数值确定出每个前缀的分数,并采用排序的方式最终获取一个最优的beam_size条路径。
Prefix Beam Search
结合动态规划以及Beam-Search进行搜索。
需要实现的子函数:
- norm_prefix:将相邻两个相同的字符去重,并将非结尾部分的blank删除。例如:[1,1,2]->[1,2] [1,1,]->[1,_]
这个图是演示的相同路径下进行归并,即相同的节点进行合并。
官方给的代码为:
"""
Author: Awni Hannun
This is an example CTC decoder written in Python. The code is
intended to be a simple example and is not designed to be
especially efficient.
The algorithm is a prefix beam search for a model trained
with the CTC loss function.
For more details checkout either of these references:
https://distill.pub/2017/ctc/#inference
https://arxiv.org/abs/1408.2873
"""
import numpy as np
import math
import collections
NEG_INF = -float("inf")
'''这段代码定义了一个名为make_new_beam的函数,该函数没有参数。这个函数的主要功能是创建并返回一个默认字典(defaultdict)。在Python中,defaultdict是内置dict类的一个子类,它重写了一个方法并添加了一个可写的实例变量。其主要特点是在直接读取dict不存在的属性值时,直接返回默认值。
fn = lambda : (NEG_INF, NEG_INF):这行代码定义了一个匿名函数(lambda函数),这个函数没有参数,每次被调用时都会返回一个元组(NEG_INF, NEG_INF)。这里的NEG_INF可能是一个在其他地方定义的常量,表示负无穷大。
return collections.defaultdict(fn):这行代码创建了一个defaultdict,它的默认值是由上面定义的fn函数生成的。也就是说,当你试图访问这个字典中不存在的键时,它会返回(NEG_INF, NEG_INF)。
总的来说,make_new_beam函数返回的是一个默认值为(NEG_INF, NEG_INF)的字典。
'''
def make_new_beam():
fn = lambda : (NEG_INF, NEG_INF)
return collections.defaultdict(fn)
'''
这段代码定义了一个名为logsumexp的函数,该函数接收任意数量的参数。这个函数的主要功能是计算所给参数的log-sum-exp,这是一种在处理概率等涉及指数运算的数值时常用的技巧,可以提高数值稳定性,防止因数值过大或过小导致的溢出或下溢。
if all(a == NEG_INF for a in args): return NEG_INF:这行代码检查所有的输入参数是否都等于NEG_INF(可能是一个在其他地方定义的表示负无穷大的常量)。如果所有参数都是NEG_INF,那么函数直接返回NEG_INF。
a_max = max(args):这行代码找出所有输入参数中的最大值。
lsp = math.log(sum(math.exp(a - a_max) for a in args)):这行代码首先计算每个输入参数与最大值的差的指数,然后将这些指数值求和,最后对求和结果取对数。这是log-sum-exp的关键步骤,通过减去最大值,可以防止指数运算的结果过大导致溢出。
return a_max + lsp:最后,函数返回最大值与上一步计算的对数求和结果的和。这就是log-sum-exp的结果。
总的来说,logsumexp函数实现了一种数值稳定的方式来计算一组数的log-sum-exp。
'''
def logsumexp(*args):
"""
Stable log sum exp.
"""
if all(a == NEG_INF for a in args):
return NEG_INF
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max)
for a in args))
return a_max + lsp
def decode(probs, beam_size=100, blank=0):
"""
Performs inference for the given output probabilities.
Arguments:
probs: The output probabilities (e.g. post-softmax) for each
time step. Should be an array of shape (time x output dim).
beam_size (int): Size of the beam to use during inference.
blank (int): Index of the CTC blank label.
Returns the output label sequence and the corresponding negative
log-likelihood estimated by the decoder.
"""
T, S = probs.shape
probs = np.log(probs)
'''
在CTC(Connectionist Temporal Classification)损失中,`p_blank`和`p_no_blank`是两个关键的概率值,它们分别表示在某个时间步上预测出空白标签(blank label)和非空白标签(non-blank label)的概率。
- `p_blank`:这个概率值通常由神经网络模型直接输出,表示在当前时间步上预测出空白标签的概率。在CTC中,空白标签是一个特殊的标签,用于表示没有输出或者输出与前一个时间步的输出相同。
- `p_no_blank`:这个概率值通常由1减去所有空白标签的概率得到,表示在当前时间步上预测出任何非空白标签的概率。
这两个概率值的确定通常依赖于你的神经网络模型的输出。具体的计算方法可能会根据你的模型和任务有所不同。例如,如果你的模型输出的是每个标签的概率,那么你可以直接使用这些概率作为`p_blank`和`p_no_blank`。如果你的模型输出的是每个标签的logits(即未经softmax或sigmoid函数处理的原始输出),那么你可能需要先将这些logits转换为概率,然后再计算`p_blank`和`p_no_blank`。
'''
# Elements in the beam are (prefix, (p_blank, p_no_blank))
# Initialize the beam with the empty sequence, a probability of
# 1 for ending in blank and zero for ending in non-blank
# (in log space).
beam = [(tuple(), (0.0, NEG_INF))]
for t in range(T): # Loop over time
# A default dictionary to store the next step candidates.
next_beam = make_new_beam()
for s in range(S): # Loop over vocab
p = probs[t, s]
# The variables p_b and p_nb are respectively the
# probabilities for the prefix given that it ends in a
# blank and does not end in a blank at this time step.
for prefix, (p_b, p_nb) in beam: # Loop over beam
# If we propose a blank the prefix doesn't change.
# Only the probability of ending in blank gets updated.
if s == blank:
'''
在CTC(Connectionist Temporal Classification)解码中,n_p_b, p_b + p, p_nb + p 是用于计算新的空白和非空白概率的中间变量。
n_p_b:这个变量表示新的空白概率,它是由当前时间步的空白概率(p_b)和非空白概率(p_nb)相加得到的。这个变量的计算反映了CTC解码的一个关键思想,即在当前时间步预测出空白标签可以由前一个时间步预测出空白标签或非空白标签两种情况转移得到。
p_b + p:这个变量表示当前时间步预测出空白标签的概率(p_b)和当前时间步的模型输出概率(p)的和。这个变量用于更新p_b,即新的空白概率。
p_nb + p:这个变量表示当前时间步预测出非空白标签的概率(p_nb)和当前时间步的模型输出概率(p)的和。这个变量用于更新p_nb,即新的非空白概率。
这三个变量的计算是CTC解码的关键步骤,它们反映了CTC解码的主要思想,即通过动态规划在所有可能的序列中找到最可能的序列。
'''
n_p_b, n_p_nb = next_beam[prefix]
n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
next_beam[prefix] = (n_p_b, n_p_nb)
continue
# Extend the prefix by the new character s and add it to the beam. Only the probability of not ending in blank
# gets updated.
end_t = prefix[-1] if prefix else None
n_prefix = prefix + (s,)
n_p_b, n_p_nb = next_beam[n_prefix]
if s != end_t:
n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
else:
# We don't include the previous probability of not ending in blank (p_nb) if s is repeated at the end. The CTC algorithm merges characters not separated by a blank.
# 不能在结尾连着出来俩blank
n_p_nb = logsumexp(n_p_nb, p_b + p)
# *NB* this would be a good place to include an LM score.
next_beam[n_prefix] = (n_p_b, n_p_nb)
# If s is repeated at the end we also update the unchanged prefix. This is the merging case.
if s == end_t:
n_p_b, n_p_nb = next_beam[prefix]
n_p_nb = logsumexp(n_p_nb, p_nb + p)
next_beam[prefix] = (n_p_b, n_p_nb)
# Sort and trim the beam before moving on to the
# next time-step.
beam = sorted(next_beam.items(),
key=lambda x : logsumexp(*x[1]),
reverse=True)
beam = beam[:beam_size]
best = beam[0]
# 返回的是数值,也可以通过best来获取其中的最优字符串
return best[0], -logsumexp(*best[1])
if __name__ == "__main__":
np.random.seed(3)
time = 50
output_dim = 20
probs = np.random.rand(time, output_dim)
probs = probs / np.sum(probs, axis=1, keepdims=True)
labels, score = decode(probs)
print("Score {:.3f}".format(score))
手语翻译采用的束搜索算法:
这段代码定义了一个名为search的函数,它实现了基于概率的搜索算法。
这个函数接受五个参数:probs,beam_width,prune,blank和lm。其中,probs是一个概率分布,beam_width是搜索宽度,prune是剪枝阈值,blank是空白符号的索引,lm是语言模型函数。
函数首先检查lm是否为None,如果是,则将其设置为一个返回1的函数。
然后,函数定义了一个名为mslm的内部函数,用于计算语言模型的概率。
接下来,函数初始化了两个字典p_b和p_nb,用于存储空白和非空白的概率。
然后,函数进入一个循环,对probs中的每个元素进行处理。在每次迭代中,函数首先找出概率大于prune的状态,然后对每个可能的状态计算其概率,并更新p_b和p_nb。
在循环结束后,函数将p_b和p_nb相加,得到总的概率分布p。
然后,函数对p进行排序,取出概率最高的beam_width个元素作为新的前缀。
最后,函数对概率进行归一化处理,避免下溢,并返回最有可能的假设。
def search(self,probs,beam_width: int = 10,prune: float = 1e-2,blank: int = 0,lm=None,alpha=0.3):
if lm is None:
lm=lambda *_:1
def mslm(l):
if len(l)==1:
return self.is_begining(l[-1])
a,b=l[-2:]
if self.is_next(a,b):
return 1
elif self.is_exiting(a,b):
return lm(self.collapse(l))**alpha
return 0
p_b = defaultdict(Counter)
p_nb = defaultdict(Counter)
p_b[-1][()] = 1
p_nb[-1][()] = 0
prefixes = [()]
for t in range(len(probs)):
pruned_states, prune_relaxed = [], prune
while not pruned_states:
pruned_states = np.where(probs[t] >= prune_relaxed)[0].tolist()
prune_relaxed /= 2
pruned_states = set(pruned_states)
for l in prefixes:
possible_states = {blank} | pruned_states
if l:
possible_states |= self.successors(l[-1])
for s in possible_states:
p_t_s = probs[t,s]
if s == blank:
p_b[t][l] += p_t_s * (p_b[t - 1][l] + p_nb[t - 1][l])
continue
ls = l + (s,)
p_lm = mslm(ls)
if l and s == l[-1]:
# a_ + a = aa
p_nb[t][ls] += p_lm * p_t_s * p_b[t - 1][l]
# a + a = a
p_nb[t][l] += p_t_s * p_nb[t - 1][l]
else:
# a(_) + b = ab
p_nb[t][ls] += p_lm * p_t_s * (p_b[t - 1][l] + p_nb[t - 1][l])
p = p_b[t] + p_nb[t]
if len(p) == 0:
p = p_b[t] # 0 prob for all prefix
if len(p) == 0:
p = p_nb[t] # 0 prob for all prefix
prefixes = sorted(p, key=lambda k: p[k], reverse=True)
prefixes = prefixes[:beam_width]
# divide by a constant (min_prob) to avoid underflow
min_prob = np.inf
for prefix in prefixes:
if min_prob > p[prefix] and p[prefix] > 0:
min_prob = p[prefix]
for prefix in prefixes:
# usually, min_prob won't be zero
p_b[t][prefix] /= min_prob
p_nb[t][prefix] /= min_prob
if p[prefixes[0]] == 0:
raise ValueError("Even the most probable beam has probability 0. ")
hyp = self.collapse(prefixes[0])
return hyp
随手画的笔记
结语
读完还是很有收获的,英语阅读能力有待提高,读的时候累死了,对于ctcloss也不是云里雾里了,估计应付面试没大问题。