Training skip-gram with PyTorch

The script below builds skip-gram (word, context) pairs from a jieba-tokenized corpus, trains a two-linear-layer network with cross-entropy loss, and saves the model's state dict.
word2vec.py
import os
import time

import jieba
import numpy as np
import torch
import torch.nn.functional as F


class SkipGram(torch.nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        # Input projection: multiplying a one-hot vector by this weight matrix
        # is an embedding lookup, so the columns of hidden.weight are the word vectors.
        self.hidden = torch.nn.Linear(self.vocab_size, self.embedding_size)
        # Output projection back to vocabulary-sized logits.
        self.predict = torch.nn.Linear(self.embedding_size, self.vocab_size)

    def forward(self, X):
        hidden = self.hidden(X)
        # Return raw logits: CrossEntropyLoss applies log-softmax internally,
        # so an extra softmax here would squash the gradients.
        return self.predict(hidden)


def build_vocab(words):
    """Map each distinct word to an id; repeated words in the corpus must share one id."""
    vocab = sorted(set(words))
    word2id = {w: i for i, w in enumerate(vocab)}
    return vocab, word2id


def data_iter(words, word2id, batch_size=3):
    """Yield (one-hot centre word, context word id) batches.

    Toy corpus for reference: 'The quick fox jumps over the lazy dog'.split()
    """
    one_hot = np.eye(len(word2id))
    context_size = 1
    x, y = [], []
    for i in range(len(words)):
        prior = words[max(i - context_size, 0):i]
        behind = words[i + 1:i + 1 + context_size]
        context = prior + behind
        # The input x is the centre word, repeated once per neighbour so that
        # we build one (word, context) training pair per context word.
        x.extend([one_hot[word2id[words[i]]] for _ in context])
        # The target y is the neighbour's id, as CrossEntropyLoss expects.
        y.extend([word2id[c] for c in context])
    idx = 0
    while idx < len(x):
        yield x[idx:idx + batch_size], y[idx:idx + batch_size]
        idx += batch_size


def cut_sentence(sentence):
    # jieba.cut returns a generator of tokens for Chinese text.
    return jieba.cut(sentence)


def train(words, word2id, batch_size=64):
    net = SkipGram(len(word2id), 20)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
    loss_fun = torch.nn.CrossEntropyLoss()
    losses = []
    for epoch in range(5):
        batch_num = 0
        for x, y in data_iter(words, word2id, batch_size):
            x, y = torch.FloatTensor(np.array(x)), torch.LongTensor(y)
            pred = net(x)
            loss = loss_fun(pred, y)
            losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_num += 1
            if batch_num % 10 == 0:
                print('epoch %d, batch_num %d, loss: %.4f'
                      % (epoch, batch_num, sum(losses) / len(losses)))
    os.makedirs('saved_model', exist_ok=True)
    date_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
    model_filename = 'skipgram_{}.pkl'.format(date_str)
    torch.save(net.state_dict(), 'saved_model/{}'.format(model_filename))
    print('model is saved as {}'.format(model_filename))
    return model_filename


def load_corpus():
    stop_word = ['【', '】', ')', '(', '、', ',', '“', '”', '。', '\n', '《', '》',
                 ' ', '-', '!', '?', '.', '\'', '[', ']', ':', '/', '.', '"',
                 '\u3000', '’', '.', ',', '…', '?', ';', '(', ')']
    with open('data/corpus.txt', 'r', encoding='utf-8') as f:
        text = f.read()
    # Strip punctuation and whitespace before tokenization.
    for i in stop_word:
        text = text.replace(i, '')
    # print(text)  # uncomment to inspect the cleaned corpus
    return text


def test():
    # Toy English corpus for a quick sanity check:
    # words = 'The quick fox jumps over the lazy dog'.split()
    text = load_corpus()
    words = list(cut_sentence(text))
    vocab, word2id = build_vocab(words)
    one_hot = np.eye(len(vocab))

    model_filename = train(words, word2id)
    net = SkipGram(len(vocab), 20)
    net.load_state_dict(torch.load('saved_model/{}'.format(model_filename)))

    word = words[5]
    print('word:', word)
    print('prediction:')
    # Softmax is applied here at inference time because forward() now
    # returns raw logits.
    print(F.softmax(net(torch.FloatTensor(one_hot[word2id[word]])), dim=-1))


if __name__ == '__main__':
    test()
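A follow-up sketch, not part of the original script: since hidden(one_hot[i]) just selects column i of hidden.weight (plus a shared bias), the trained columns can be read out as word vectors and compared directly. The helper name nearest is hypothetical, and it assumes the vocab/word2id pair returned by build_vocab above.

import torch
import torch.nn.functional as F

def nearest(net, word2id, vocab, word, k=5):
    # hidden.weight has shape (embedding_size, vocab_size); after the
    # transpose, row i is the embedding of word i. The shared bias is
    # dropped, a common simplification when reading out embeddings.
    vectors = net.hidden.weight.data.t()            # (vocab_size, embedding_size)
    query = vectors[word2id[word]]
    sims = F.cosine_similarity(query.unsqueeze(0), vectors, dim=1)
    top = torch.topk(sims, k + 1).indices.tolist()  # k+1 so the word itself can be skipped
    return [vocab[i] for i in top if vocab[i] != word][:k]

# Example usage after training:
# print(nearest(net, word2id, vocab, words[5]))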
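Separately, for anything beyond a toy vocabulary the one-hot input is wasteful; the idiomatic equivalent feeds word ids into torch.nn.Embedding. A minimal sketch of that variant (SkipGramEmbedding is a hypothetical name; training would reuse the same CrossEntropyLoss loop, with x as a LongTensor of centre-word ids instead of one-hot rows):

import torch

class SkipGramEmbedding(torch.nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        # The embedding lookup replaces the one-hot * Linear multiplication.
        self.embed = torch.nn.Embedding(vocab_size, embedding_size)
        self.predict = torch.nn.Linear(embedding_size, vocab_size)

    def forward(self, word_ids):
        # word_ids: LongTensor of centre-word ids, shape (batch,)
        return self.predict(self.embed(word_ids))  # logits over the vocabulary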