tensorflow 使用预训练好的模型的一部分参数

   

vars = tf.global_variables()

net_var = [var for var in vars if 'bi-lstm_secondLayer' not in var.name and 'word_embedding1s' not in var.name

and 'proj_secondLayer' not in var.name

]

   

saver_pre = tf.train.Saver(net_var)

   

saver_pre.restore(self.sess, tf.train.latest_checkpoint(self.config.dir_model_storepath_pre))

   

'''

with tf.variable_scope('bi-lstm',reuse=True):

fwk=tf.get_variable('bidirectional_rnn/fw/lstm_cell/kernel')

fwb=tf.get_variable('bidirectional_rnn/fw/lstm_cell/bias')

bwk = tf.get_variable('bidirectional_rnn/bw/lstm_cell/kernel')

bwb = tf.get_variable('bidirectional_rnn/bw/lstm_cell/bias')

   

saver_pre= tf.train.Saver({'words/_word_embeddings':self._word_embeddings,

'bi-lstm/bidirectional_rnn/fw/lstm_cell/kernel':fwk,

'bi-lstm/bidirectional_rnn/fw/lstm_cell/bias':fwb,

'bi-lstm/bidirectional_rnn/bw/lstm_cell/kernel':bwk,

'bi-lstm/bidirectional_rnn/bw/lstm_cell/bias':bwb})

 

for x in tf.trainable_variables():

print(x.name)

   

#mysaver = tf.train.import_meta_graph(self.config.dir_model_storepath_pre_graph)

   

saver_pre.restore(self.sess, tf.train.latest_checkpoint(self.config.dir_model_storepath_pre))

'''

posted @   simple_wxl  阅读(2279)  评论(0编辑  收藏  举报
编辑推荐:
· Linux glibc自带哈希表的用例及性能测试
· 深入理解 Mybatis 分库分表执行原理
· 如何打造一个高并发系统?
· .NET Core GC压缩(compact_phase)底层原理浅谈
· 现代计算机视觉入门之:什么是图片特征编码
阅读排行:
· 手把手教你在本地部署DeepSeek R1,搭建web-ui ,建议收藏!
· Spring AI + Ollama 实现 deepseek-r1 的API服务和调用
· 数据库服务器 SQL Server 版本升级公告
· C#/.NET/.NET Core技术前沿周刊 | 第 23 期(2025年1.20-1.26)
· 程序员常用高效实用工具推荐,办公效率提升利器!
点击右上角即可分享
微信分享提示