import torch
from torch import nn, optim
import torch.nn.functional as F

CONTEXT_SIZE = 2  # 2 words to the left, 2 to the right
raw_text = "We are about to study the idea of a computational process. Computational processes are abstract beings that inhabit computers. As they evolve, processes manipulate other abstract things called data. The evolution of a process is directed by a pattern of rules called a program. People create programs to direct processes. In effect, we conjure the spirits of the computer with our spells.".split(' ')

vocab = set(raw_text)
word_to_idx = {word: i for i, word in enumerate(vocab)}

# Build (context, target) pairs: for every position with a full window, the
# context is the 2 * CONTEXT_SIZE surrounding words and the target is the centre word.
data = []
for i in range(CONTEXT_SIZE, len(raw_text) - CONTEXT_SIZE):
    context = [raw_text[i - 2], raw_text[i - 1], raw_text[i + 1], raw_text[i + 2]]
    target = raw_text[i]
    data.append((context, target))

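# Illustrative check of one training pair: the first one is
# (['We', 'are', 'to', 'study'], 'about').
print(data[0])
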
class CBOW(nn.Module):
    def __init__(self, n_word, n_dim, context_size):
        super(CBOW, self).__init__()
        self.embedding = nn.Embedding(n_word, n_dim)
        # The 2 * context_size context embeddings are concatenated before the
        # hidden layer, hence the input width of the first linear layer.
        self.linear1 = nn.Linear(2 * context_size * n_dim, 128)
        self.linear2 = nn.Linear(128, n_word)

    def forward(self, x):
        x = self.embedding(x)
        x = x.view(1, -1)  # flatten the context embeddings into one row
        x = self.linear1(x)
        x = F.relu(x, inplace=True)
        x = self.linear2(x)
        x = F.log_softmax(x, dim=1)  # log-probabilities over the vocabulary
        return x


model = CBOW(len(word_to_idx), 100, CONTEXT_SIZE)
if torch.cuda.is_available():
    model = model.cuda()

# The model's forward already applies log_softmax, so use NLLLoss here;
# CrossEntropyLoss would apply log_softmax a second time.
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)

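# Illustrative sanity check: nn.Embedding is a lookup table, so indexing it with a
# word id returns that word's 100-dimensional vector (the row training will adjust).
example_idx = torch.tensor([word_to_idx['process']])
if torch.cuda.is_available():
    example_idx = example_idx.cuda()
print('embedding shape for "process":', tuple(model.embedding(example_idx).shape))  # (1, 100)
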
for epoch in range(100):
    print('epoch {}'.format(epoch))
    print('*' * 10)
    running_loss = 0
    for context, target in data:
        # Turn the context words and the target word into index tensors.
        context = torch.LongTensor([word_to_idx[w] for w in context])
        target = torch.LongTensor([word_to_idx[target]])
        if torch.cuda.is_available():
            context = context.cuda()
            target = target.cuda()
        # forward
        out = model(context)
        loss = criterion(out, target)
        running_loss += loss.item()
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('loss: {:.6f}'.format(running_loss / len(data)))
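
# Illustrative check after training: predict the centre word for the first context.
idx_to_word = {i: w for w, i in word_to_idx.items()}
context, target = data[0]
context_idx = torch.LongTensor([word_to_idx[w] for w in context])
if torch.cuda.is_available():
    context_idx = context_idx.cuda()
pred = idx_to_word[model(context_idx).argmax(dim=1).item()]
print('context: {} | target: {} | predicted: {}'.format(context, target, pred))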