PyTorch学习笔记之CBOW模型实践


复制代码

import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
 5 
 6 CONTEXT_SIZE = 2  # 2 words to the left, 2 to the right
 7 raw_text = "We are about to study the idea of a computational process. Computational processes are abstract beings that inhabit computers. As they evolve, processes manipulate other abstract things called data. The evolution of a process is directed by a pattern of rules called a program. People create programs to direct processes. In effect, we conjure the spirits of the computer with our spells.".split(' ')
 8 
 9 vocab = set(raw_text)
10 word_to_idx = {word: i for i, word in enumerate(vocab)}
11 
12 data = []
13 for i in range(CONTEXT_SIZE, len(raw_text)-CONTEXT_SIZE):
14     context = [raw_text[i-2], raw_text[i-1], raw_text[i+1], raw_text[i+2]]
15     target = raw_text[i]
16     data.append((context, target))
17 
18 
class CBOW(nn.Module):
    """Continuous bag-of-words model: predict a centre word from its context.

    The context word embeddings are concatenated (not averaged) and passed
    through a two-layer MLP that scores every word in the vocabulary.

    Args:
        n_word: vocabulary size (number of embedding rows and output classes).
        n_dim: embedding dimension.
        context_size: number of context words on EACH side of the target, so
            each sample carries ``2 * context_size`` word indices.
    """

    def __init__(self, n_word, n_dim, context_size):
        super(CBOW, self).__init__()
        self.embedding = nn.Embedding(n_word, n_dim)
        # Input width: all 2 * context_size embeddings laid end to end.
        self.linear1 = nn.Linear(2 * context_size * n_dim, 128)
        self.linear2 = nn.Linear(128, n_word)

    def forward(self, x):
        """Return log-probabilities over the vocabulary.

        Args:
            x: LongTensor of word indices, either a single context of shape
                ``(2 * context_size,)`` or a batch of shape
                ``(batch, 2 * context_size)``.

        Returns:
            Log-softmax scores of shape ``(1, n_word)`` for a single context,
            or ``(batch, n_word)`` for a batch.
        """
        if x.dim() == 1:
            # An un-batched context is treated as a batch of one, which gives
            # the same (1, n_word) output as before.
            x = x.unsqueeze(0)
        x = self.embedding(x)            # (batch, 2*context_size, n_dim)
        x = x.view(x.size(0), -1)        # (batch, 2*context_size*n_dim)
        x = self.linear1(x)
        x = F.relu(x, inplace=True)
        x = self.linear2(x)
        # Explicit dim: the implicit-dim form of log_softmax is deprecated.
        return F.log_softmax(x, dim=1)
34 
35 
# Train the CBOW model with plain SGD, one (context, target) sample per step,
# printing the mean loss over all samples after each epoch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CBOW(len(word_to_idx), 100, CONTEXT_SIZE).to(device)

# forward() already returns log-probabilities, so NLLLoss is the matching
# criterion; CrossEntropyLoss would apply a redundant second log_softmax.
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(100):
    print('epoch {}'.format(epoch))
    print('*' * 10)
    running_loss = 0.0
    for context, target in data:
        # Map words to vocabulary indices. Plain tensors replace the
        # deprecated torch.autograd.Variable wrapper.
        context_idx = torch.tensor(
            [word_to_idx[w] for w in context], dtype=torch.long, device=device
        )
        target_idx = torch.tensor(
            [word_to_idx[target]], dtype=torch.long, device=device
        )
        # forward
        out = model(context_idx)
        loss = criterion(out, target_idx)
        # The loss is a 0-dim tensor: loss.data[0] raises on modern PyTorch,
        # while .item() returns the Python float.
        running_loss += loss.item()
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('loss: {:.6f}'.format(running_loss / len(data)))
posted @ 2019-06-24 14:42  交流_QQ_2240410488  阅读(557)  评论(0)  编辑  收藏  举报