How to correctly load pretrained word vectors in TensorFlow

Using pretrained word vectors rather than randomly initialized ones makes a substantial difference, so here is the workflow I use to load pretrained vectors.

  1. Build the vocabulary of the corpus at hand; this serves as my base vocabulary.

  2. Iterate over that vocabulary and, for each word, extract the corresponding vector from the pretrained word vectors.

  3. Initialize the embeddings variable and assign the data to the tensor (a minimal sketch of this step follows the list).
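
To make step 3 concrete before the full listing: a common TF1 idiom is to create an empty variable and copy the pretrained matrix into it through a placeholder, so the large array is not baked into the GraphDef. The sketch below is a minimal, self-contained version of that idea; the sizes and the random stand-in matrix are made up, and in practice you would feed in the W built by the code further down.

import numpy as np
import tensorflow as tf

vocab_size, embed_dim = 5, 300  # hypothetical sizes; use your real vocabulary size
W = np.random.uniform(-0.25, 0.25, (vocab_size, embed_dim)).astype('float32')  # stand-in for a real pretrained matrix

# The variable holds the embeddings; set trainable=True if you want to fine-tune them.
embedding = tf.Variable(tf.constant(0.0, shape=[vocab_size, embed_dim]),
                        trainable=False, name="embedding")
embedding_ph = tf.placeholder(tf.float32, shape=[vocab_size, embed_dim])
embedding_init = embedding.assign(embedding_ph)  # run once to copy the numpy data into the tensor

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(embedding_init, feed_dict={embedding_ph: W})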

Sample code:

  

#-*- coding: UTF-8 -*-
import numpy as np
import tensorflow as tf
'''This program only does some simple preprocessing for word2vec; applying it inside a more complex model will require changes to fit the actual situation.'''

class Wordlist(object):
    def __init__(self, filename, maxn=100000):
        # One entry per line; the word itself is the first token on the line.
        with open(filename) as f:
            lines = [line.split() for line in f.readlines()[:maxn]]
        self.size = len(lines)

        # Map each word to its line index.
        self.voc = dict((tokens[0], idx) for idx, tokens in enumerate(lines))

    def getID(self, word):
        return self.voc.get(word, 0)

def get_W(word_vecs, k=300):
    """
    Get word matrix. W[i] is the vector for the word indexed by i.
    Row 0 stays all-zero and serves as the padding/unknown vector.
    """
    vocab_size = len(word_vecs)
    word_idx_map = dict()
    W = np.zeros(shape=(vocab_size + 1, k), dtype='float32')
    for i, word in enumerate(word_vecs, start=1):
        W[i] = word_vecs[word]
        word_idx_map[word] = i
    return W, word_idx_map

def load_bin_vec(fname, vocab):
    """
    Loads 300-dim word vecs from the Google (Mikolov) word2vec binary file,
    keeping only the words that appear in vocab.
    """
    word_vecs = {}
    pury_word_vec = []
    with open(fname, "rb") as f:
        header = f.readline()
        print('header:', header)
        vocab_size, layer1_size = map(int, header.split())
        print('vocab_size:', vocab_size, 'layer1_size:', layer1_size)
        binary_len = np.dtype('float32').itemsize * layer1_size
        for _ in range(vocab_size):
            # Each entry starts with the word as raw bytes, terminated by a space.
            word = b''
            while True:
                ch = f.read(1)
                if ch == b' ':
                    break
                if ch != b'\n':
                    word += ch
            word = word.decode('utf-8', errors='ignore')
            if word in vocab:
                word_vecs[word] = np.frombuffer(f.read(binary_len), dtype='float32')
                pury_word_vec.append(word_vecs[word])
            else:
                # Skip the vector of a word we do not need.
                f.read(binary_len)
        #np.savetxt('googleembedding.txt', pury_word_vec)
    return word_vecs, pury_word_vec

def add_unknown_words(word_vecs, vocab, min_df=1, k=300):
    """
    For words that occur in at least min_df documents, create a separate word vector.
    0.25 is chosen so the unknown vectors have (approximately) the same variance as the pre-trained ones.
    Note: vocab here maps word -> index, so with min_df=1 every missing word
    except the one at index 0 receives a random vector.
    """
    for word in vocab:
        if word not in word_vecs and vocab[word] >= min_df:
            word_vecs[word] = np.random.uniform(-0.25, 0.25, k)

if __name__ == "__main__":
    w2v_file = "GoogleNews-vectors-negative300.bin"  # the Google News word2vec .bin file
    print("loading data...")
    vocab = Wordlist('vocab.txt')  # the vocabulary your own dataset needs
    w2v, pury_word2vec = load_bin_vec(w2v_file, vocab.voc)
    add_unknown_words(w2v, vocab.voc)
    W, word_idx_map = get_W(w2v)

    '''a minimal embedding_lookup example'''
    Wa = tf.Variable(W)
    embedding_input = tf.nn.embedding_lookup(Wa, [0, 1, 2])  # replace the indices with your actual document ids

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        input = sess.run(Wa)
        #print(np.shape(input))
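
Continuing from the listing above (this reuses Wa and word_idx_map from it and is only a hedged usage sketch; the input sentence is made up), mapping tokens to ids and looking up their vectors works like this:

tokens = "the quick brown fox".split()  # hypothetical input sentence
ids = [word_idx_map.get(w, 0) for w in tokens]  # unseen words fall back to row 0, the all-zero padding vector
inputs = tf.nn.embedding_lookup(Wa, ids)  # shape: [len(tokens), 300]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(inputs).shape)  # (4, 300)

Note that tf.Variable(W) as used in the listing serializes the whole matrix into the graph; for a very large vocabulary, the placeholder-and-assign pattern sketched after the step list is usually the safer choice.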

 
