优化计算速度
下面的代码是参照百度的warp-ctc实现,为了方便理解将原始的C++转换成了python。

(注:很多博客里的这张图都是错的,会出现跳过真实标签直接将两个空格连接起来的情况,同一行连续空格是允许的)
如图是T=12,label='state'全部可能的路线,可以看出在不同的timestep,路线可能的选点位置是有限制的。比如t=1时,只能选前两个点,也就是和
。在t=5时,可以选择
到
间的任意一个点。因此我们在计算前向递推概率和
和后向递推概率和
的过程中,不需要遍历
中的每一个点,可以大大减小计算数量。
举个简单的例子,假如T=3,label='cat',这时只有一条路线,t=1,2,3时,分别只能选'c','a','t'。假如T=4,lable='cat',当t=1时,可以选择blank或'c'。t=2时,选择的起点至少为'c',如果这一步还选不到'c'的话,剩下的两个timestep是不够选'cat'三个字符的。

接下来继续以图1中T=12,label='state'的例子来解释代码
label = 'state'
L = len(label)
S = 2 * L + 1
T = 12
repeats = 0
start = 0 if S//2 + repeats - T < 0 else 1
end = 2 if S > 1 else 1 # end取不到,只能取到end-1
首先在label的首尾以及每两个字符间加上blank,长度变成2*L+1=11。接下来计算第一个timestep的选点的start和end位置,其中S//2得到的是插入空格前的label原始长度5,最终在对结果进行转换时要去掉连续重复字符和blank,如果gt中有连续重复字符,识别结果中重复字符间要有blank才能转换回去,例如'pp'转换为'p','p-p'转换为'pp'。repeats实际上是需要在label中连续重复字符中间添加的blank数量,在'state'中没有连续重复字符,因此repeats=0;在'apple'中,需要在'pp'中间添加一个blank,因此repeats=1;在'appple'中,需要在'ppp'中间添加两个blank,因此repeats=2。
当S//2 + repeats - T < 0时,start=0也就是blank,此时T大于所有必要的选点数S//2 + repeats,也就是有冗余的情况,我们可以在多余的timestep中插入blank或者连续的字符,在转换时这些blank以及连续的字符都会被删掉。另一种情况是S//2 + repeats - T = 0,此时只有一条路线,如上面label='cat',T=3的例子。不可能出现S//2 + repeats - T > 0的情况。代码中得到的区间为[start, end),左闭右开。
s_inc = [1]
e_inc = []
repeats = 0
label = 'state'
# label = 'satte'
L = len(label)
for i in range(1, L):
if label[i-1] == label[i]:
s_inc.append(1)
s_inc.append(1)
e_inc.append(1)
e_inc.append(1)
repeats += 1
else:
s_inc.append(2)
e_inc.append(2)
e_inc.append(1)
print(s_inc) # [1, 2, 2, 2, 2]
print(e_inc) # [2, 2, 2, 2, 1]
# label='satte'
print(s_inc) # [1, 2, 2, 1, 1, 2]
print(e_inc) # [2, 2, 1, 1, 2, 1]
上面代码中的s_inc和e_inc分别表示起点和终点从当前位置到下一个必要字符的步长,这里的必要字符指的是label里的字符以及重复字符间的blank,例如label=’state',则必要字符就是'state',如label='apple',则必要字符为'ap-ple'。代码中label='state',start起始位置为第一个空格,第一个空格到第一个字符's'的步长为1,因为在label的两端和字符间添加了空格,因此从字符's'到第二个字符't'的步长为2,依次类推。
start的起始位置为第一个空格,end的起始位置为第一个字符,start第一个步长为1从第一个空格到第一个字符,end第一个步长为2从第一个字符到第二个字符;start的终止位置为最后一个字符,end的终止位置为最后一个空格,start最后一个步长为2从倒数第二个字符到最后一个字符,end最后一个步长为1从最后一个字符到最后一个空格。
当label中出现连续重复字符时,start和end都必须从第一个字符到空格再到第二个字符,因此连续两个步长为1。
下面是实际计算每个timestep的start和end位置的代码
for t in range(1, T):
remain = S // 2 + repeats - (T - t)
if remain >= 0:
start += s_inc[remain]
if t <= S//2 + repeats:
end += e_inc[t-1]
首先只要remain<0意味着剩下的timestep可以覆盖label,也就是从头到尾走完label,可以参照图1理解,这时start可以一直保持在第一个空格的位置不动。一旦remain=0即到达临界点,此时剩下的timestep刚好可以走完label,但只有一条路线,每个timestep都经过必要字符,不允许有多余的blank和重复字符了。s_inc是我们前面计算出的起点经过每个必要字符的步长,一旦remain>=0后start依次加s_inc里的每个步长就可以了。
end的位置是每个timestep所能达到的最远位置,而最快的路线就是从第一个timestep开始每次都经过必要字符,中间没有多余的blank和重复,因此end从第一个timestep就依次加e_inc里的步长就可以了。
对数域优化
(一)
对任意一条路径,有

其中是介于0到1之间的概率,如果T非常大,几百个小于1的浮点数相乘可能导致underflow,因此我们将计算过程转到对数域上,连乘变成了连加
(二)

(三)
如果有N个概率,我们要求其对数域之和

如果很大,
会很大导致上溢出或下溢出, 优化方法如下

其中取
中的最大项,证明如下

求的代码如下,将CTC Loss(一)里的alpha_vanilla转换到了对数域上
def alpha(log_y, labels):
T, V = log_y.shape
L = len(labels)
log_alpha = np.ones([T, L]) * -np.float('inf')
# init
log_alpha[0, 0] = log_y[0, labels[0]]
log_alpha[0, 1] = log_y[0, labels[1]]
for t in range(1, T):
for i in range(L):
s = labels[i]
a = log_alpha[t - 1, i]
if i - 1 >= 0:
a = logsumexp(a, log_alpha[t - 1, i - 1])
if i - 2 >= 0 and s != 0 and s != labels[i - 2]:
a = logsumexp(a, log_alpha[t - 1, i - 2])
log_alpha[t, i] = a + log_y[t, s]
return log_alpha
其中,函数logsumexp代码如下
def logsumexp(a, b):
"""
np.log(np.exp(a) + np.exp(b))
"""
if a < b:
a, b = b, a
if b == -np.float('inf'):
return a
else:
return a + np.log(1 + np.exp(b - a))
CTC Loss(一)中计算梯度的函数gradient转化后如下
def backward(log_y, labels):
T, V = log_y.shape
log_alpha = alpha(log_y, labels)
log_beta = beta(log_y, labels)
log_p = logsumexp(log_alpha[-1, -1], log_alpha[-1, -2])
log_grad = np.ones([T, V]) * -np.float('inf')
for t in range(T):
for s in range(V):
lab = [i for i, c in enumerate(labels) if c == s]
for i in lab:
log_grad[t, s] = logsumexp(log_grad[t, s],
log_alpha[t, i] + log_beta[t, i])
log_grad[t, s] -= 2 * log_y[t, s]
log_grad -= log_p
return log_grad
参考
https://zhuanlan.zhihu.com/p/23309693
本文介绍了如何优化CTC Loss计算速度,通过分析Timestep选点限制减少计算量,并详细解释了对数域优化的原理,包括避免浮点数下溢和上溢的方法。此外,提供了计算alpha和gradient的转换代码,便于理解CTC Loss的实现。
&spm=1001.2101.3001.5002&articleId=117418563&d=1&t=3&u=f6d7eab9d89a4985a0a3b48896d664d1)
2664

被折叠的 条评论
为什么被折叠?



