MAML算法概述
MAML算法概述
什么是MAML
1. 论文地址:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
2. 要解决的问题
- 小样本问题
- 模型收敛过慢
3. 算法描述
MAML期望通过训练一组初始化参数
,使得模型透过训练出的初始化参数,未来在少量样本基础上实现快速收敛。该初始化参数 在训练集上未必是最优解,但可以通过训练出的参数在新的任务上快速收敛,找到最优解。
4. V.S. Pre-train
- Pre-train:训练集上的全局最优参数,但放到测试集上未必可以训练出全局最优,可能只会找到局部最优点。
- MAML:在训练集和测试集上未必全局最优参数,但通过少量迭代,便可收敛到全局最优。
算法描述
- 随机初始化一个权重θ
- Setp3 ~ Step10:一个epoch
- 随机采样一个batch的Task
- 遍历所有Task
- 从Support Set中取出一个batch的Task中的Label和Image
- Setp6 ~ Step7:前向传播,计算梯度后反向传播,更新θ′这个权重
- 从Query Set中取出所有Task前向传播,但不更新模型
- Step10:所有Task结束后,计算Loss,计算梯度,更新θ的权重
核心代码
for epoch in range(args.epoch//10000):
# fetch meta_batchsz num of episode each time
db = DataLoader(mini, args.task_num, shuffle=True, num_workers=0, pin_memory=True)
for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)
accs = maml(x_spt, y_spt, x_qry, y_qry)
if step % 30 == 0:
print('step:', step, '\ttraining acc:', accs)
if step % 500 == 0: # evaluation
db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=0, pin_memory=True)
accs_all_test = []
for x_spt, y_spt, x_qry, y_qry in db_test:
x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
accs_all_test.append(accs)
# [b, update_step+1]
accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
print('Test acc:', accs)
本文作者:HoroSherry
本文链接:https://www.cnblogs.com/horolee/p/maml.html
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列1:轻松3步本地部署deepseek,普通电脑可用
· 按钮权限的设计及实现
· 【杂谈】分布式事务——高大上的无用知识?