Overview of the Matching Network Algorithm
What is a Matching Network?
1. Paper: Matching Networks for One Shot Learning
2. Overview: borrows ideas from metric learning and augments the network with an external memory to improve its learning ability.
3. Novel contributions
- Draws on prior work on attention and external memory to build the network
- Trains on tasks in a meta-learning fashion, instead of feeding images of fixed classes as in metric learning
4. Algorithm description
A Matching Network takes two inputs:
- a task S: an N-way K-shot support set $S = \{(x_i, y_i)\}_{i=1}^{k}$ (for example, a 4-way 1-shot task, as in the paper's figure);
- an image $\hat{x}$ whose class needs to be predicted.

The output of the Matching Network is defined as the predicted class $\hat{y}$ of the image $\hat{x}$, i.e. the distribution $P(\hat{y} \mid \hat{x}, S)$.

The Matching Network algorithm can then be constructed as

$$\hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i)\, y_i$$

where the parametric mapping from $S$ to the classifier is built from the two components described next: attention and external memory.
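To make the weighted-sum formula concrete, here is a minimal sketch (the tensors and their values are my own toy example, not from the paper or the code below): the prediction is just the attention weights applied to the one-hot support labels.

```python
import torch

# toy 4-way 1-shot task: 4 support images, each its own class
attention = torch.tensor([0.7, 0.1, 0.1, 0.1])  # a(x_hat, x_i), sums to 1
support_labels_one_hot = torch.eye(4)           # y_i as one-hot rows, [k, n_classes]
y_hat = attention @ support_labels_one_hot      # class distribution for x_hat
print(y_hat)                  # tensor([0.7000, 0.1000, 0.1000, 0.1000])
print(y_hat.argmax().item())  # 0: the class of the most-attended support image
```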
(1) Attention
- Simple form
The paper first gives a simple attention paradigm: $a(\hat{x}, x_i)$ can be an arbitrary kernel on $X \times X$, in which case the prediction above acts like a kernel density estimator; if the kernel is zero everywhere except on the $b$ nearest support points, it acts like a nearest-neighbour classifier.
- Cosine-distance attention
Intuitively, a natural idea is to compare two images by computing their similarity/distance in an embedding vector space. The paper defines a cosine-distance attention:

$$a(\hat{x}, x_i) = \frac{e^{c(f(\hat{x}),\, g(x_i))}}{\sum_{j=1}^{k} e^{c(f(\hat{x}),\, g(x_j))}}$$

where $f$ and $g$ are embedding functions (neural networks, e.g. CNNs) and $c(\cdot, \cdot)$ is the cosine similarity.
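A minimal sketch of this cosine-softmax attention, assuming the embeddings $f(\hat{x})$ and $g(x_i)$ have already been computed (the function name and shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def cosine_attention(f_query, g_support):
    # f_query: [d] embedding f(x_hat); g_support: [k, d] embeddings g(x_i)
    cos = F.cosine_similarity(f_query.unsqueeze(0), g_support, dim=1)  # [k]
    return F.softmax(cos, dim=0)  # softmax over the support set -> weights sum to 1

# sanity check: the support embedding closest to the query gets the largest weight
f_q = torch.randn(64)
g_s = torch.stack([f_q, torch.randn(64), torch.randn(64)])
print(cosine_attention(f_q, g_s))  # first weight is the largest
```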
(2) 外部记忆
作者作者认为上述的余弦注意力定义的时候,(输入任务S中)每个已知标签的输入
对此,作者提出了双向LSTM
来解决这个问题。
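A sketch of the full-context-embedding idea under the assumptions above: run a bidirectional LSTM over the stacked per-image embeddings and add a skip connection, so each support embedding sees the whole set. (The paper sums the two LSTM directions; PyTorch's BiLSTM concatenates them, so here each direction uses half the hidden size. Layer sizes are illustrative.)

```python
import torch
import torch.nn as nn

class FullContextEmbedding(nn.Module):
    # Re-embeds a set of vectors in the context of the whole set with a BiLSTM.
    def __init__(self, dim):
        super().__init__()
        # dim // 2 per direction so the concatenated output matches the input dim
        self.lstm = nn.LSTM(dim, dim // 2, bidirectional=True)

    def forward(self, embeddings):          # embeddings: [k, batch, dim]
        context, _ = self.lstm(embeddings)  # [k, batch, dim], set-conditioned
        return context + embeddings         # skip connection, as in the paper

fce = FullContextEmbedding(64)
support = torch.randn(5, 2, 64)  # 5 support images, batch of 2 tasks
print(fce(support).shape)        # torch.Size([5, 2, 64])
```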
5. Network design
Algorithm description:
- Pass all images $x_i$ in the task S together with the target image $\hat{x}$ through a CNN to obtain their embedding vectors, then stack these vectors.
- Feed the stacked embeddings into a bidirectional LSTM to obtain $k + 1$ outputs, then use cosine distance to measure the similarity between each of the first $k$ outputs and the last one.
- Using the computed similarities and the label information in the task S, derive the class label of the target image $\hat{x}$.
Core code
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Classifier, BidirectionalLSTM, DistanceNetwork and AttentionalClassify are
# the helper modules defined elsewhere in the same repository.


class MatchingNetwork(nn.Module):
    def __init__(self, keep_prob, batch_size=100, num_channels=1, learning_rate=0.001,
                 fce=False, num_classes_per_set=5, num_samples_per_class=1,
                 nClasses=0, image_size=28):
        """
        Builds a matching network, along with its loss and evaluation ops.
        :param keep_prob: Dropout keep probability, a float in [0, 1]
        :param batch_size: The batch size for the experiment
        :param num_channels: Number of channels of the images
        :param fce: Flag indicating whether to use full context embeddings (i.e. apply an LSTM on the CNN embeddings)
        :param num_classes_per_set: Integer indicating the number of classes per set
        :param num_samples_per_class: Integer indicating the number of samples per class
        :param nClasses: Total number of classes; sets the output size of the classifier g's final FC layer
        :param image_size: Size of the input image, needed to build the final FC classification layer
        """
        super(MatchingNetwork, self).__init__()
        self.batch_size = batch_size
        self.fce = fce
        # g: CNN embedding function shared by support and target images
        self.g = Classifier(layer_size=64, num_channels=num_channels,
                            nClasses=nClasses, image_size=image_size)
        if fce:
            # full context embeddings: a BiLSTM over the stacked CNN embeddings
            self.lstm = BidirectionalLSTM(layer_sizes=[32], batch_size=self.batch_size,
                                          vector_dim=self.g.outSize)
        self.dn = DistanceNetwork()             # cosine similarities
        self.classify = AttentionalClassify()   # softmax attention over support labels
        self.keep_prob = keep_prob
        self.num_classes_per_set = num_classes_per_set
        self.num_samples_per_class = num_samples_per_class
        self.learning_rate = learning_rate

    def forward(self, support_set_images, support_set_labels_one_hot, target_image, target_label):
        """
        Builds the graph for Matching Networks, producing losses and accuracy.
        :param support_set_images: Support set images, [batch_size, sequence_size, n_channels, 28, 28]
        :param support_set_labels_one_hot: One-hot support set labels, [batch_size, sequence_size, n_classes]
        :param target_image: Images to produce labels for, [batch_size, n_target, n_channels, 28, 28]
        :param target_label: Target labels, [batch_size, n_target]
        :return: mean accuracy and mean cross-entropy loss over the target images
        """
        # produce embeddings for support set images
        encoded_images = []
        for i in np.arange(support_set_images.size(1)):
            gen_encode = self.g(support_set_images[:, i, :, :, :])
            encoded_images.append(gen_encode)

        # produce embeddings for target images, one query at a time
        for i in np.arange(target_image.size(1)):
            gen_encode = self.g(target_image[:, i, :, :, :])
            encoded_images.append(gen_encode)
            outputs = torch.stack(encoded_images)

            if self.fce:
                outputs, hn, cn = self.lstm(outputs)

            # get similarity between support set embeddings and target embedding
            similarities = self.dn(support_set=outputs[:-1], input_image=outputs[-1])
            similarities = similarities.t()

            # produce predictions for target probabilities
            preds = self.classify(similarities, support_set_y=support_set_labels_one_hot)

            # accumulate accuracy and cross-entropy loss
            values, indices = preds.max(1)
            if i == 0:
                accuracy = torch.mean((indices.squeeze() == target_label[:, i]).float())
                crossentropy_loss = F.cross_entropy(preds, target_label[:, i].long())
            else:
                accuracy = accuracy + torch.mean((indices.squeeze() == target_label[:, i]).float())
                crossentropy_loss = crossentropy_loss + F.cross_entropy(preds, target_label[:, i].long())

            # drop this query's embedding so the next target image lands at the end of the list
            encoded_images.pop()

        return accuracy / target_image.size(1), crossentropy_loss / target_image.size(1)
```
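A hedged usage sketch of the forward pass above. It only illustrates the expected tensor shapes; actually running it requires the repository's Classifier, BidirectionalLSTM, DistanceNetwork and AttentionalClassify modules, and all concrete numbers here are my own example values:

```python
import torch

# a 5-way 1-shot task on 28x28 single-channel images, batch of 32 tasks
model = MatchingNetwork(keep_prob=1.0, batch_size=32, num_channels=1, fce=True,
                        num_classes_per_set=5, num_samples_per_class=1,
                        image_size=28)
support_images = torch.randn(32, 5, 1, 28, 28)               # [batch, k, C, H, W]
support_labels = torch.eye(5).unsqueeze(0).repeat(32, 1, 1)  # one-hot, [batch, k, n_classes]
target_images = torch.randn(32, 1, 1, 28, 28)                # one query image per task
target_labels = torch.randint(0, 5, (32, 1))                 # [batch, n_target]
acc, loss = model(support_images, support_labels, target_images, target_labels)
loss.backward()  # the returned loss is differentiable, so it can drive training
```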