FSL-GNN代码解读

FSL-GNN代码解读

main.py(主函数)

1、加载数据集:

train_loader = generator.Generator(args.dataset_root, args, partition='train', dataset=args.dataset)

2、初始化或加载模型:

enc_nn = models.load_model('enc_nn', args, io)
metric_nn = models.load_model('metric_nn', args, io)

if enc_nn is None or metric_nn is None:
	enc_nn, metric_nn = models.create_models(args=args)
    softmax_module = models.SoftmaxModule()

models.create_models(args=args) : in models.py

def create_models(args):
    print (args.dataset)
    if 'omniglot' == args.dataset:
        enc_nn = EmbeddingOmniglot(args, 64)
    elif 'mini_imagenet' == args.dataset:
        enc_nn = EmbeddingImagenet(args, 128)
    else:
        raise NameError('Dataset ' + args.dataset + ' not knows')
    return enc_nn, MetricNN(args, emb_size=enc_nn.emb_size)

class EmbeddingOmniglot():				# 特征提取
class EmbeddingImagenet():				# 略

class MetricNN(nn.Module):
	if self.metric_network == 'gnn_iclr_nl':……		# 正常的网络
	self.gnn_obj = gnn_iclr.GNN_nl()			# in gnn_iclr.py
	
	elif self.metric_network == 'gnn_iclr_active':……	# 主动学习
	self.gnn_obj = gnn_iclr.GNN_active()# in gnn_iclr.py
	
class SoftmaxModule():		# 线性分类

class GNN_nl(nn.Module) & class GNN_active(nn.Module) : in gnn_iclr.py

class GNN_nl(nn.Module):		# 图网络主要部分
	class Wcompute(nn.Module)	# W邻接矩阵计算
    class Gconv(nn.Module)		# 组图
		def gmul(input)		# 更新图节点特征,W直接返回

3、训练

# 权重衰减
weight_decay = 1e-6

# 优化器
opt_enc_nn = optim.Adam(enc_nn.parameters(), lr=args.lr, weight_decay=weight_decay)
opt_metric_nn = optim.Adam(metric_nn.parameters(), lr=args.lr, weight_decay=weight_decay)

# 梯度置零,也就是把loss关于weight的导数变成0
opt_enc_nn.zero_grad()
opt_metric_nn.zero_grad()

# 训练
loss_d_metric = train_batch(
	model=[enc_nn, metric_nn, 
	softmax_module],
	data=[batch_x, label_x, batches_xi, labels_yi, oracles_yi, hidden_labels])

# 更新参数
opt_enc_nn.step()
opt_metric_nn.step()

# 自适应参数
adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn], lr=args.lr, iter=batch_idx)

# 显示训练中loss的更新
if batch_idx % args.log_interval == 0:
	display_str = 'Train Iter: {}'.format(batch_idx)
	display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss/counter)
	io.cprint(display_str)

# 测试
def test_one_shot(args, model, test_samples=5000, partition='test') 定义于 test.py 中
val_acc_aux = test.test_one_shot	# 验证集上测试
test_acc_aux = test.test_one_shot	# 测试集上测试
test.test_one_shot(					# 训练集上测试
	args, 
	model=[enc_nn, metric_nn, softmax_module],
	test_samples=test_samples, 
	partition='train')				

# 测试完毕,将模型设置回训练状态
enc_nn.train()
metric_nn.train()

# 若在验证集上的效果继续变好,则更新
if val_acc_aux is not None and val_acc_aux >= val_acc:

# 保存模型
torch.save(enc_nn, 'checkpoints/%s/models/enc_nn.t7' % args.exp_name)
torch.save(metric_nn, 'checkpoints/%s/models/metric_nn.t7' % args.exp_name)

# 全部训练完毕后进行测试
test.test_one_shot
posted @ 2021-10-06 14:45  SethDeng  阅读(420)  评论(0编辑  收藏  举报