爱嘉牛LA


1. Purpose

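Train a word2vec model with Python's gensim module, then load it in TensorFlow and visualize the embedding vectors with TensorBoard's Embedding Projector. The scripts use the gensim 3.x and TensorFlow 1.x APIs (tf.contrib no longer exists in TensorFlow 2.x).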

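P.S. A model trained with the C/C++ word2vec tool cannot be opened with gensim's Word2Vec.load, which only reads models saved by gensim itself (model.save). If the tool wrote the standard word2vec text or binary format, gensim.models.KeyedVectors.load_word2vec_format(path, binary=True) (binary=False for the text format) can still load the vectors, but they cannot be trained further.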
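Word2Vec.load expects gensim's own saved model, while the original C word2vec tool (run with -binary 1) writes a much simpler layout: a text header "<vocab_size> <dim>", then each word followed by a space and its raw float32 values. The sketch below parses that layout with only the standard library, to show what a loader has to do. It assumes the C/C++ tool used that same layout, that words are UTF-8, and that floats are little-endian; read_word2vec_binary is a hypothetical helper, and in practice KeyedVectors.load_word2vec_format does this job.

```python
import io
import struct
from typing import BinaryIO, Dict, List


def read_word2vec_binary(f: BinaryIO) -> Dict[str, List[float]]:
    """Parse the layout written by the original C word2vec tool with -binary 1.

    Assumed layout: a text header with vocab size and dimension, then for each
    word the UTF-8 word, one space, and dim little-endian float32 values,
    usually followed by a newline.
    """
    vocab_size, dim = (int(x) for x in f.readline().split())
    vectors: Dict[str, List[float]] = {}
    for _ in range(vocab_size):
        word = bytearray()
        while True:
            ch = f.read(1)
            if ch == b" ":
                break
            if not ch:
                raise EOFError("file ended in the middle of a word")
            if ch != b"\n":  # skip the newline that ends the previous record
                word.extend(ch)
        raw = f.read(4 * dim)
        if len(raw) != 4 * dim:
            raise EOFError("file ended in the middle of a vector")
        vectors[word.decode("utf-8")] = list(struct.unpack("<%df" % dim, raw))
    return vectors


if __name__ == "__main__":
    # Build a tiny two-word file in memory instead of reading a real model.
    data = (
        b"2 3\n"
        + b"hello " + struct.pack("<3f", 1.0, 2.0, 3.0) + b"\n"
        + "世界".encode("utf-8") + b" " + struct.pack("<3f", 0.5, -1.0, 0.0) + b"\n"
    )
    print(read_word2vec_binary(io.BytesIO(data)))
```

Because the vectors are read by exact byte length rather than by scanning for delimiters, float bytes that happen to equal a space or newline do not confuse the parser.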

2. Training the word2vec model in Python

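import multiprocessing

from gensim.models.word2vec import Word2Vec, LineSentence

print('Training started')
# Training corpus: one sentence per line, tokens already separated by whitespace
# (Chinese text must be word-segmented beforehand).
train_file = "/tmp/train_data"

# gensim 3.x parameter names; in gensim 4.x, size -> vector_size and iter -> epochs.
model = Word2Vec(LineSentence(train_file), size=128,
                 workers=multiprocessing.cpu_count(), iter=10)
print('Training finished')

# Replace every vector with its unit-length version to save memory
# (the model cannot be trained further after this).
model.init_sims(replace=True)
# Saved in gensim's native format, not the C word2vec .bin format,
# so it has to be reloaded with Word2Vec.load().
model.save('/tmp/emb.bin')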
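Two calls in the training script do more than their names suggest. LineSentence treats each line of train_file as one sentence and splits it on whitespace, cutting very long lines into chunks of max_sentence_length tokens (10000 by default). init_sims(replace=True) rescales every word vector to unit L2 length, so the dot product of two words equals their cosine similarity. The pure-Python sketch below mimics both steps for illustration only; line_sentences and unit_normalize are hypothetical helpers, not gensim functions.

```python
import math
from typing import Iterable, Iterator, List


def line_sentences(lines: Iterable[str], max_sentence_length: int = 10000) -> Iterator[List[str]]:
    """Yield one token list per line, in the spirit of gensim's LineSentence.

    Tokens are split on whitespace; empty lines yield nothing, and lines longer
    than max_sentence_length tokens are cut into consecutive chunks.
    """
    for line in lines:
        tokens = line.split()
        for start in range(0, len(tokens), max_sentence_length):
            yield tokens[start:start + max_sentence_length]


def unit_normalize(vector: List[float]) -> List[float]:
    """Scale a vector to L2 norm 1, the per-word effect of init_sims(replace=True)."""
    norm = math.sqrt(sum(x * x for x in vector))
    if norm == 0.0:
        # Zero vectors are returned unchanged (a choice made for this sketch).
        return list(vector)
    return [x / norm for x in vector]


if __name__ == "__main__":
    # A pre-segmented corpus: tokens separated by spaces, one sentence per line.
    corpus = ["我 喜欢 自然 语言 处理\n", "\n", "word2vec 把 词 映射 成 向量\n"]
    print(list(line_sentences(corpus)))
    print(unit_normalize([3.0, 4.0]))  # [0.6, 0.8]
```

Since line_sentences accepts any iterable of strings, passing an open text file gives the same line-by-line streaming that keeps a large corpus out of memory.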

3. Loading the model in TensorFlow for visualization

# Requires TensorFlow 1.x (tf.contrib was removed in 2.0) and gensim 3.x.
import os

import numpy as np
import tensorflow as tf
from gensim.models.word2vec import Word2Vec
from tensorflow.contrib.tensorboard.plugins import projector

log_dir = '/tmp/embedding_log'
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

# load the model saved by the training script
model_file = '/tmp/emb.bin'
word2vec = Word2Vec.load(model_file)

# Build the embedding matrix. Word vectors live on word2vec.wv; take the word
# list once so that row i of the matrix and line i of metadata.tsv match.
words = list(word2vec.wv.vocab.keys())
embedding = np.empty((len(words), word2vec.vector_size), dtype=np.float32)
for i, word in enumerate(words):
    embedding[i] = word2vec.wv[word]

# Set up a TensorFlow session. The variable starts as a one-element dummy and
# is overwritten through a placeholder, so the matrix is fed in at run time
# instead of being stored in the graph as a constant (graphs are limited to 2 GB).
tf.reset_default_graph()
sess = tf.InteractiveSession()
X = tf.Variable([0.0], name='embedding')
place = tf.placeholder(tf.float32, shape=embedding.shape)
set_x = tf.assign(X, place, validate_shape=False)
sess.run(tf.global_variables_initializer())
sess.run(set_x, feed_dict={place: embedding})

# write labels: one word per line, in the same order as the matrix rows
with open(os.path.join(log_dir, 'metadata.tsv'), 'w', encoding='utf-8') as f:
    for word in words:
        f.write(word + '\n')

# create a summary writer and point the projector at the embedding variable
summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
config = projector.ProjectorConfig()
embedding_conf = config.embeddings.add()
embedding_conf.tensor_name = 'embedding:0'
embedding_conf.metadata_path = os.path.join(log_dir, 'metadata.tsv')
projector.visualize_embeddings(summary_writer, config)

# save a checkpoint holding the embedding variable
saver = tf.train.Saver()
saver.save(sess, os.path.join(log_dir, "model.ckpt"))
summary_writer.close()

# Afterwards run: tensorboard --logdir /tmp/embedding_log and open the Projector tab.
print("Done!")
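The projector matches labels to vectors purely by position: line i of metadata.tsv labels row i of the embedding variable, which is why the word list is best built once and reused for both. The metadata file may also carry extra columns, in which case its first line is treated as a header row. The sketch below writes either form; write_metadata and the counts mapping are illustrative names, not part of the original script (with gensim 3.x one could fill counts from word2vec.wv.vocab[word].count).

```python
import io
from typing import Dict, Optional, Sequence, TextIO


def write_metadata(f: TextIO, words: Sequence[str],
                   counts: Optional[Dict[str, int]] = None) -> None:
    """Write Embedding Projector labels: line i describes row i of the matrix.

    A single column (just the word) needs no header; with extra columns the
    first line is a header row.
    """
    for word in words:
        # A tab or newline inside a label would shift every following row.
        if "\t" in word or "\n" in word:
            raise ValueError("label contains a tab or newline: %r" % word)
    if counts is None:
        for word in words:
            f.write(word + "\n")
    else:
        f.write("word\tcount\n")
        for word in words:
            f.write("%s\t%d\n" % (word, counts[word]))


if __name__ == "__main__":
    buf = io.StringIO()
    # Hypothetical labels and frequencies, just to show the two-column layout.
    write_metadata(buf, ["北京", "上海"], counts={"北京": 12, "上海": 7})
    print(buf.getvalue(), end="")
```

Validating every label before writing anything means a bad word fails fast instead of leaving a half-written, misaligned metadata file.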

 

Posted on 2019-01-04 19:18 by 爱嘉牛LA