ECAPA-TDNN结合代码的理解
ECAPA-TDNN网络架构被分成了三个小节,下面将对ecapa-tdnn模型架构以及代码进行详细分析。
-
依赖通道和时序的统计池化。
-
通道注意力模块
-
多层特征的聚合
建模通道和时序依赖关系的统计池化
其本质是将时序注意力机制延伸到通道注意力,形成通道-时序注意力方法。在此之前需要先了解一下文中提到的 soft self attention。
soft self attention
假设经过一些列特征的抽象后,最终得到的特征张量维度为[N,T,D],其中N为batch size;T是时间维度,也就是帧数;D是维度,是每帧有多少维数表示。
soft attention的重点是获取时间维度上的依赖关系,可以自己决定哪个时间片段的作用大一些,哪一个时间片段的作用小一些。
首先将输入的数据进行线性变换,获得维度为[N,T,D]的输出张量。
然后为了能让网络自适的获取不同时间片段的权重,需要利用一组可以学习的网络节点,将输出张量的维度变成[N,T],经过softmax后,就获得了与时间相关的权重。
权重来自于输入,最终还是要跟输入相乘,实施自注意力选择机制。输入的维度是[N,T,D],但是权重的维度是[N,T],如何让二者相乘呢?
在这里涉及到一个比较关键的点,那就是soft attention默认当前时间步上的所有通道的权重都是一样的。所以获取完权重后,只需要将当前时间片的权重复制D次,使当前时间步上的权重数量等于通道的数量。这样即可将权重的维度从[N,T]变换成[N,T,D]。
input*weight后,其输出为[N,T,D],在时间维度上实施了注意力机制。
最后获取var和mean,消减时间步。将mean和var进行concat,输出的维度变成[N,D*2]
可结合代码进行更加具象的理解。
class Classic_Attention(nn.Module):
"""
获取时序注意力权重的类
"""
def __init__(self,input_dim, embed_dim, attn_dropout=0.0):
super().__init__()
self.embed_dim = embed_dim
self.attn_dropout = attn_dropout
self.lin_proj = nn.Linear(input_dim,embed_dim)
self.v = torch.nn.Parameter(torch.randn(embed_dim))
def forward(self,inputs):# [N,T,D]
lin_out = self.lin_proj(inputs) # 线性变换。输出维度:[n,t,d]
v_view = self.v.unsqueeze(0).expand(lin_out.size(0), len(self.v)).unsqueeze(2) # 增加网络节点,维度为[N,d,1]
# .bmm 进行相乘。[N,T,D]*[N,D,1]=[N,T,1]
# BMM:https://pytorch.org/docs/stable/generated/torch.bmm.html
attention_weights = F.tanh(lin_out.bmm(v_view).squeeze()) # 获得权重,维度为:[N,T]
attention_weights_normalized = F.softmax(attention_weights,1)
return attention_weights_normalized
def stat_attn_pool(self,inputs,attention_weights):
el_mat_prod = torch.mul(inputs,attention_weights.unsqueeze(2).expand(-1,-1,inputs.shape[-1]))
# 增加维度:attention_weights.unsqueeze(2),增加后维度为[N,T,1];
# 扩充维度:expand(-1,-1,inputs.shape[-1]),扩充后维度为[N,T,512],将第三个维度上的数据复制512份,这样每个时间步上,不同通道的权重是一样的
mean = torch.mean(el_mat_prod,1) # 获取均值,消除时间维度
variance = self.weighted_sd(inputs,attention_weights,mean)
stat_pooling = torch.cat((mean,variance),1)
return stat_pooling # 输出维度[N,D*2]
Channel- and context-dependent statistics pooling
soft attention获取的的权重,在某个时间步上,不同通道的权重都是一样的。能不能在某个时间步上,不同的通道的权重是不一样的?当然可以,这就是增加了通道注意力的上下文注意力机制。
为了让注意力机制获取语料的全局特性,作者对时序语料进行了扩展。即在计算权重前,首先将输入与输入的全局mean,var进行拼接。
# 将输入x,mean,var按照时间维度进行拼接。
global_x = torch.cat((x,torch.mean(x,dim=2,keepdim=True).repeat(1,1,t), torch.sqrt(torch.var(x,dim=2,keepdim=True).clamp(min=1e-4)).repeat(1,1,t)), dim=1)
下面的池化层,不仅依赖上下文(时间步),同时依赖通道,这样可以获取更多的说话人特征。
输入特征首先经过一个线性影射,将其投影到一个较低的维度上,所有的通道共享权重,以减少参数量,并防止过拟合。
接着经过tanh非线性变换,一个线性映射将维度复原,就获得了依赖通道的注意力得分.这个得分的维度和输入的维度是一样的,都是[N,D,T],这样每个通道的权重就不一样了。对得分进行softmax变换即可获得权重。
将权重与输入相乘,获取mean,var,将二者拼接,即可获得输出[N,D*2]
class AttentiveStatsPool(nn.Module):
def __init__(self, in_dim, bottleneck_dim=128):
super().__init__()
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1) # equals V and k in the paper
def forward(self, x): # 输入x维度:(N,D,T)
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
alpha = torch.tanh(self.linear1(x)) # (N,d,T)
alpha = torch.softmax(self.linear2(alpha), dim=2) # (N,D,T)
mean = torch.sum(alpha * x, dim=2) # (N,D)
residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
std = torch.sqrt(residuals.clamp(min=1e-9)) # 获取std
return torch.cat([mean, std], dim=1) # (N,D*2)
以上就是依赖通道和时序的注意力机制,接下来看一下如何将视觉的SE模块应用进来。
SE模块-通道注意力机制
图像中的SE模块
图像中SE模块的作用是,建模特征通道之间的相互依赖关系来提高网络的表达能力。
技术手段为:利用输入特征获取每个通道的权值,利用其权值大小使网络关注到更为重要的信息。(跟上述的通道上下文池化是同样的道理)
SE模块的组成:global average pooling - FC - ReLU - FC- sigmoid
全局平均池化:通过 squeeze 压缩操作,将跨空间维度H*W的特征映射进行聚合,生成一个通道描述符,HxWxC → 1x1xC
第一个FC用来降维,第二个FC用来增维,ReLU函数用于限制模型的复杂性和帮助训练。
门控机制使用sigmod,每个通道通过一个 基于通道依赖 的自选门机制 来学习特定样本的激活,使其学会使用全局信息,有选择地强调信息特征,并抑制不太有用的特征
论文中的SE模块
论文中提出了一维的SE模块,仿照图像中的SE模块,首先获取通道描述符,获得一个向量。向量中每个值的获取方式为:按时间维度计算当前通道的和,然后求均值,这是个数值,所有通道的计算结果组合在一起,形成一个向量。其维度为[1,1,C]。
然后获得权重(参考上述池化层注意力)。
文中利用sigma激活函数将输出缩放到0-1,作为权重。然后与输入相乘,便可以得到输出。这样就可以获得通道之间的相互依赖关系。
将SE用在了残差模块中:升维-膨胀卷积-降维-SEblock。
多层特征的聚合
文中的第三个改进就是进行了多特征的融合。
浅层的特征对于提取较为鲁棒的向量也是有贡献的。因此将之前所有SE-RES2BLOCK层的输出进行拼接。
x = torch.cat((x1,x2,x3),dim=1)
另外作者还将之前所有SE-Res2Blocks的输出作为当前网络层的输入。
x1 = self.layer1(x)
x2 = self.layer2(x+x1) # 将上一层的输出与输入相加,作为当前网络层的输入。
x3 = self.layer3(x+x1+x2)
LOSS
AAMLOSS
采用了AAMloss,何为AAMloss?
使用AAMloss可以让模型学到的人声特征在不同的人之间的差异性变得更大,相同人之间的差异性变得更小一点。
经过损失函数形式、欧式空间变换、参数归一化和省略b之后,假设只有p1和p2两个类别,那么这两个类别的决策平面是p1=p2,也就是二者的输出概率相等时的值就是决策平面,如果想p1的概率大于p2的概率,那么就需要p1的角度,小于p2的角度。cos是减函数。通过将p1的角度
后续的A-softloss,AM-softloss,AAM-softloss都是围绕下述公式进行优化的:
想要实现中间的这个小于
其不断优化的理由是为了更好的收敛,以及有一个更直观的解释。
以上就是对ecapa-tdnn论文的理解,如有问题,欢迎探讨~
附:
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· C#/.NET/.NET Core技术前沿周刊 | 第 29 期(2025年3.1-3.9)
· 从HTTP原因短语缺失研究HTTP/2和HTTP/3的设计差异