“Triplet network”三元组网络阅读笔记
记录《DEEP METRIC LEARNING USING TRIPLET NETWORK》阅读笔记
文章总体内容:
作者在前人提出的多个特征提取方法的基础上提出Triplet network模型,通过比较距离来学习有用的变量(深度学习中拟合出函数),在多个不同的数据集显示Triplet network比直接计算方法的Siamese network模型效果更好。
Triplet network基本原理:
在Siamese network中,会出现如下的问题,当使用随机对象的数据集时,一个对象可能被认为与另一个对象相似,但是当我们只想区分一组个体中的两个对象时,可能被认为与同样的另一个对象不相似。当选取特征时,并不能够足够判断两者之间的关系,在面对训练样本数量较少的简单分类问题,可能会产生误差。 因此,作者提出了Triplet network,利用三个样本组成一个训练组,从中获取拟合函数。
其基本结构如图1所示:
图1 Triplet network基本结构图
Triplet network由3个具有相同前馈网络(共享参数)组成。 接收到3个样本时,网络输出2个中间值表示与第三个变量之间的欧式距离。3个输入表示为x,x+和x-,并将网络的嵌入层表示表示为Net(x)。 简单来说,triplet是一个三元组,这个三元组是这样构成的:从训练数据集中随机选一个样本,该样本称为Anchor,然后再随机选取一个和Anchor (记为x)属于同一类的样本和不属于同一类的样本,这两个样本对应的称为Positive (记为x+)和Negative (记为x-),由此构成一个(Anchor,Positive,Negative)三元组。他们之间的关系用欧氏距离表示,并通过训练参数使得x向x+靠近,远离x-,从而实现分类任务。
图2 triplet示意图(图片来自网络)
Triplet距离与目标函数:
距离采用的计算为欧氏距离,即L-2范式距离,如下所示:
比较器即是对上述该向量进行处理,训练的loss可规定如下,针对三元组中的每个元素(样本),训练一个参数共享的网络,得到三个元素的特征表达,
分别记为:
通过训练,让x+和x特征表达之间的距离尽可能小,而和x-和x的特征表达之间的距离尽可能大,并且要让x+和x之间的距离,和x-和x之间的距离之间有一个最小的间隔,在本文中使用的间隔为1。即:
其中:
通过d+和d-可以把初始变量归一化到(0,1)范围内。
论文创新点:
1. 由于Triplet network模型允许通过比较样本而不是直接数据标签进行学习,因此可以将其用作无监督学习模型。
2. 由于Triplet network模型采用三元组作为训练样本,在数据量较少的简单分类任务中表现要准确得多。在多个数据集上取得更优的分类结果。