Training a skip-gram model with PyTorch
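
The script below trains skip-gram word vectors from scratch: each word enters as a one-hot vector, a linear hidden layer projects it down (those weights are the embeddings), and a second linear layer scores every vocabulary word as a context candidate. The network is trained with cross-entropy to predict the words in a one-word window around the input. The corpus is Chinese text segmented with jieba.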

word2vec.py

import os
import time

import jieba
import numpy as np
import torch
import torch.nn.functional as F


class SkipGram(torch.nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        # input->hidden weights are the word embeddings (one column per word)
        self.hidden = torch.nn.Linear(self.vocab_size, self.embedding_size)
        # hidden->output scores every vocabulary word as a context candidate
        self.predict = torch.nn.Linear(self.embedding_size, self.vocab_size)

    def forward(self, X):
        hidden = self.hidden(X)
        # return raw logits: CrossEntropyLoss applies log-softmax internally,
        # so an extra softmax here would flatten the gradients
        return self.predict(hidden)

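
def _shape_check():
    # Illustrative sanity check (hypothetical helper, not called by the script):
    # a batch of 3 one-hot rows over a 10-word vocabulary yields 3 rows of
    # logits, one score per vocabulary word.
    net = SkipGram(10, 5)
    logits = net(torch.eye(10)[:3])
    assert logits.shape == (3, 10)
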

def data_iter(words, batch_size=3):
    '''Yield (one-hot center word, context word id) batches, e.g. for
    words = 'The quick fox jumps over the lazy dog'.split()
    '''
    vocab = sorted(set(words))
    word2id = {w: i for i, w in enumerate(vocab)}
    one_hot = np.eye(len(vocab))

    context_size = 1
    x = []
    y = []
    for i in range(len(words)):
        prior = words[max(0, i - context_size):i]  # max() avoids a negative slice start at the sentence head
        behind = words[i + 1:i + 1 + context_size]
        context = prior + behind
        # the input x is the center word, repeated once per context word,
        # to build len(context) separate (word, context) pairs
        x.extend([one_hot[word2id[words[i]]] for _ in context])
        # the target y is the neighbouring word's id, since the loss is CrossEntropy
        y.extend([word2id[c] for c in context])

    idx = 0
    while idx < len(x):
        yield (x[idx:idx + batch_size], y[idx:idx + batch_size])
        idx += batch_size
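
def _demo_pairs():
    # Illustrative only (hypothetical helper, not called by the script):
    # with context_size=1 each word predicts its immediate neighbours,
    # e.g. 'quick' -> 'The' and 'quick' -> 'fox'.
    words = 'The quick fox jumps over the lazy dog'.split()
    for x_batch, y_batch in data_iter(words, batch_size=4):
        print(len(x_batch), y_batch)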

def cut_sentence(sentence):
    # jieba.cut returns a generator over the segmented tokens
    word_list = jieba.cut(sentence)
    return word_list
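
# Example segmentation (from jieba's documentation):
# list(jieba.cut('我来到北京清华大学')) -> ['我', '来到', '北京', '清华大学']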

def train(words, batch_size=64):
    vocab_size = len(set(words))  # must match the vocabulary built in data_iter
    net = SkipGram(vocab_size, 20)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
    loss_fun = torch.nn.CrossEntropyLoss()
    losses = []
    for epoch in range(5):
        batch_num = 0
        for x, y in data_iter(words, batch_size):
            x, y = torch.FloatTensor(np.array(x)), torch.LongTensor(y)
            pred = net(x)
            loss = loss_fun(pred, y)
            losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_num += 1
            if batch_num % 10 == 0:
                # running average over all batches seen so far
                print('epoch %d, batch_num %d, loss: %.4f' % (epoch, batch_num, sum(losses) / len(losses)))

    date_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
    model_filename = 'skipgram_{}.pkl'.format(date_str)
    os.makedirs('saved_model', exist_ok=True)  # make sure the target directory exists
    torch.save(net.state_dict(), 'saved_model/{}'.format(model_filename))
    print('model is saved as {}'.format(model_filename))
    return model_filename

def load_corpus():
    # punctuation and whitespace stripped from the raw corpus before segmentation
    stop_word = ['，', '。', '、', '；', '：', '？', '！', '“', '”', '‘', '’',
                 '（', '）', '(', ')', '《', '》', '\n', '\u3000', ' ', '-',
                 '.', '\'', '[', ']', '/', '"', ',', '?']

    with open('data/corpus.txt', 'r', encoding='utf-8') as f:
        text = f.read()
    for i in stop_word:
        text = text.replace(i, '')
    # print(text)  # debug
    return text

def test():
    # toy alternative: text = 'The quick fox jumps over the lazy dog'; words = text.split()
    text = load_corpus()
    words = list(cut_sentence(text))
    vocab = sorted(set(words))  # same vocabulary as data_iter builds
    word2id = {w: i for i, w in enumerate(vocab)}
    one_hot = np.eye(len(vocab))

    model_filename = train(words)
    net = SkipGram(len(vocab), 20)
    net.load_state_dict(torch.load('saved_model/{}'.format(model_filename)))

    idx = word2id[words[5]]
    print('word:', words[5])
    print('prediction:')
    # softmax turns the logits into a probability distribution over context words
    print(F.softmax(net(torch.FloatTensor(one_hot[idx])), dim=-1))
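

def nearest(word, words, net, topk=5):
    # Sketch of a similarity lookup (hypothetical helper, not called by the
    # script): the columns of the input->hidden weight matrix are the learned
    # word vectors (the Linear bias is ignored here), so rank the vocabulary
    # by cosine similarity to `word`.
    vocab = sorted(set(words))
    word2id = {w: i for i, w in enumerate(vocab)}
    emb = net.hidden.weight.data.t()            # (vocab_size, embedding_size)
    v = emb[word2id[word]].unsqueeze(0)         # (1, embedding_size)
    sims = F.cosine_similarity(emb, v, dim=1)   # (vocab_size,)
    best = sims.argsort(descending=True)[1:topk + 1]  # skip the word itself
    return [vocab[i] for i in best.tolist()]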

if __name__ == '__main__':
    test()
