思路:
  1. 根据变量名称(variable scope)筛选出需要更新的权重;
  2. 如果各部分参数需要分开更新,还需要为每一部分分别设置优化器。
 
代码示例:
def Net_1(input):
    """Build the first sub-network inside the ``Net_1`` variable scope.

    The visible part applies one convolution to ``input``, max-pools over the
    full spatial extent of the resulting feature map, flattens the pooled
    tensor, and computes clipped softmax probabilities over 3 classes.

    NOTE(review): the original post elides the remainder of the body (the
    ``。。。。`` line), so the returned ``tmp`` is not defined in the visible
    code; what this function actually returns cannot be determined from here.
    """
    with tf.variable_scope('Net_1'):
        # Positional arguments: 32 filters, kernel size 32, strides (1, 1).
        # NOTE(review): a kernel size of 32 is unusual — confirm it is intended.
        fmap_input = tf.layers.conv2d(input,32,32,(1,1),padding='same',name='conv1')
        # Static shape of the feature map; the unpacking requires a 4-D tensor.
        _, xh, xw, xc = fmap_input.get_shape().as_list()
        # Pool window equals the feature map's height and width. Despite the
        # name 'gap', this is max pooling, not average pooling.
        gap = tf.layers.max_pooling2d(fmap_input,(xh,xw),strides=(1,1),name='gap')
        _,h,w,c = gap.get_shape().as_list()
        # Flatten to (batch, c); presumably h == w == 1 after pooling — TODO confirm.
        gap = tf.reshape(gap,(-1,c))
        # 3-way classification logits.
        cls_logit = tf.layers.dense(gap,3,name='fc')
        cls_probs_soft = tf.nn.softmax(cls_logit, axis=1)
        # Clamp probabilities to [1e-7, 1.0].
        cls_probs = tf.clip_by_value(cls_probs_soft, 1e-7, 1.0)
  。。。。
  return tmp
 
def Net_2(input_fm):
    """Build the classification head inside the ``Net_2`` variable scope.

    Max-pools ``input_fm`` with a window covering its full height and width,
    flattens the pooled tensor to ``(batch, channels)``, applies a 3-unit
    dense layer, and returns the softmax probabilities clipped to
    ``[1e-8, 1.0]``.

    Args:
        input_fm: 4-D feature-map tensor whose static shape is read to size
            the pooling window.

    Returns:
        Tensor of clipped class probabilities with 3 columns.
    """
    with tf.variable_scope('Net_2'):
        # Static height/width of the incoming map define the pooling window.
        _, height, width, _ = input_fm.get_shape().as_list()
        pooled = tf.layers.max_pooling2d(
            input_fm, (height, width), strides=(1, 1), name='gap_cls_head')

        # Collapse the pooled map into one feature vector per sample.
        _, _, _, channels = pooled.get_shape().as_list()
        flat = tf.reshape(pooled, (-1, channels))

        logits = tf.layers.dense(flat, 3)
        probs = tf.nn.softmax(logits, axis=1)
        # Clamp probabilities to [1e-8, 1.0].
        return tf.clip_by_value(probs, 1e-8, 1.0)
 
将 Net_1 的输出作为 Net_2 的输入:
# Graph inputs: a float32 tensor of shape [None, 128, 128, 3] and an int32
# label vector of shape [None].
input_placeholder = tf.placeholder(
    dtype=tf.float32,
    shape=[None, 128, 128, 3],
    name='input_',
)
gt = tf.placeholder(
    dtype=tf.int32,
    shape=[None],
    name='label',
)
# Non-trainable counter variable, initialised to 0.
global_step = tf.Variable(
    0,
    name='globel_step',
    trainable=False,
)

# Chain the two sub-networks: Net_2 consumes whatever Net_1 returns.
output1 = Net_1(input_placeholder)
output2 = Net_2(output1)

# `损失` ("loss") is the post's stand-in for a loss function that is not
# defined in this snippet; one loss is built per network output.
loss = 损失(output1, gt)
loss1 = 损失(output2, gt)

# Net_1 update: build a training op that only modifies Net_1's variables.
optimizer_anet = tf.train.AdamOptimizer(0.01)
# Method 1: select the trainable variables by variable scope.
net1_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Net_1')
# Method 2 (alternative): filter all trainable variables by name substring.
# tvars1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
# net1_vars = [vr for vr in tvars1 if 'Net_1' in vr.name]
# Print the selected variables so the filter can be checked by eye.
for tmp in net1_vars:
    print('net1--->', tmp.name)
# Run any registered update ops (e.g. batch-norm moving averages) before the
# optimizer step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # var_list restricts the gradient update to Net_1's variables.
    train_op_net1 = optimizer_anet.minimize(
        loss, global_step=global_step, var_list=net1_vars)

# Net_2 update: a separate optimizer that only modifies Net_2's variables.
optimizer = tf.train.AdamOptimizer(0.01)
# Method 1: select the trainable variables by variable scope.
other_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Net_2')
# Method 2 (alternative): filter all trainable variables by name substring.
# tvars1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
# other_vars = [vr for vr in tvars1 if 'Net_2' in vr.name]
# Print the selected variables so the filter can be checked by eye.
for tmp in other_vars:
    print('net_2-->', tmp.name)
# Run any registered update ops (e.g. batch-norm moving averages) before the
# optimizer step.
update_ops1 = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops1):
    # The loss on Net_2's output is bound to `loss1`; the original referenced
    # an undefined `loss2` here. var_list restricts the update to Net_2.
    # NOTE(review): global_step is also passed here, so running this op
    # increments the shared step counter — confirm that is intended.
    train_op_net2 = optimizer.minimize(
        loss1, global_step=global_step, var_list=other_vars)

    
with tf.Session() as sess:
    # Initialise both global and local variables before running any train op.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    # `数据` ("data") is the post's stand-in for real batches; substitute the
    # actual image batch and label batch. Labels are fed through `gt`, the
    # label placeholder variable (the original used an undefined `label`).
    # Update Net_1 only.
    _ = sess.run(train_op_net1, feed_dict={input_placeholder: 数据, gt: 数据})
    # Update Net_2 only.
    _ = sess.run(train_op_net2, feed_dict={input_placeholder: 数据, gt: 数据})
posted on 2023-05-08 18:51  一点飞鸿  阅读(108)  评论(0)  编辑  收藏  举报