PyTorch -- CNN 文本分类 --《Convolutional Neural Networks for Sentence Classification》

论文《Convolutional Neural Networks for Sentence Classification》(Yoon Kim, 2014)使用 CNN 实现了文本(句子)分类。

论文地址:https://arxiv.org/abs/1408.5882

模型图(原文图片缺失,可参见论文中的 Figure 1):

  

 

模型的详细解释请参考论文。以下代码及注释参考自:https://github.com/graykode/nlp-tutorial

 1 # -*- coding: utf-8 -*-
 2 # @time : 2019/11/9  13:55
 3 
 4 import numpy as np
 5 import torch
 6 import torch.nn as nn
 7 import torch.optim as optim
 8 from torch.autograd import Variable
 9 import torch.nn.functional as F
10 
11 dtype = torch.FloatTensor
12 
13 # Text-CNN Parameter
14 embedding_size = 2 # n-gram
15 sequence_length = 3
16 num_classes = 2  # 0 or 1
17 filter_sizes = [2, 2, 2] # n-gram window
18 num_filters = 3
19 
20 # 3 words sentences (=sequence_length is 3)
21 sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
22 labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.
23 
24 word_list = " ".join(sentences).split()
25 word_list = list(set(word_list))
26 word_dict = {w: i for i, w in enumerate(word_list)}
27 vocab_size = len(word_dict)
28 
29 inputs = []
30 for sen in sentences:
31     inputs.append(np.asarray([word_dict[n] for n in sen.split()]))
32 
33 targets = []
34 for out in labels:
35     targets.append(out) # To using Torch Softmax Loss function
36 
37 input_batch = Variable(torch.LongTensor(inputs))
38 target_batch = Variable(torch.LongTensor(targets))
39 
40 
class TextCNN(nn.Module):
    """CNN sentence classifier in the style of "Convolutional Neural
    Networks for Sentence Classification".

    Pipeline: embedding lookup -> one Conv2d branch per filter size ->
    ReLU -> max-over-time pooling -> concatenation -> linear layer.

    Every constructor argument is optional; an omitted one falls back to
    the module-level setting of the same name, so ``TextCNN()`` builds the
    model configured at the top of the script.

    Args:
        vocab_size: number of rows in the embedding table ``W``.
        embedding_size: width of each word vector.
        num_classes: number of output logits.
        filter_sizes: filter height (n-gram window) of each conv branch.
        num_filters: output channels of each conv branch.
    """

    def __init__(self, vocab_size=None, embedding_size=None, num_classes=None,
                 filter_sizes=None, num_filters=None):
        super().__init__()

        settings = globals()

        def pick(value, name):
            # Use the explicit argument if given, else the module-level value.
            return settings[name] if value is None else value

        vocab_size = pick(vocab_size, "vocab_size")
        embedding_size = pick(embedding_size, "embedding_size")
        num_classes = pick(num_classes, "num_classes")
        filter_sizes = pick(filter_sizes, "filter_sizes")
        num_filters = pick(num_filters, "num_filters")

        self.num_filters_total = num_filters * len(filter_sizes)
        self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1))
        self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1))
        self.Bias = nn.Parameter(0.1 * torch.ones([num_classes]))

        # Build the conv layers once and register them through ModuleList so
        # their weights appear in parameters() and are trained by the
        # optimizer. Creating them inside forward() would give new random,
        # untrainable filters on every call.
        self.convs = nn.ModuleList(
            nn.Conv2d(1, num_filters, (size, embedding_size), bias=True)
            for size in filter_sizes
        )

    def forward(self, X):
        """Return class logits of shape [batch_size, num_classes].

        Args:
            X: integer tensor of word indices, expected shape
               [batch_size, seq_len] with seq_len >= max(filter_sizes).
        """
        embedded_chars = self.W[X]  # [batch_size, seq_len, embedding_size]
        embedded_chars = embedded_chars.unsqueeze(1)  # add channel(=1): [batch_size, 1, seq_len, embedding_size]

        pooled_outputs = []
        for conv in self.convs:
            h = F.relu(conv(embedded_chars))  # [batch_size, num_filters, seq_len - filter_size + 1, 1]
            # Max-over-time pooling: the kernel spans the full conv output height.
            pooled = F.max_pool2d(h, (h.size(2), 1))  # [batch_size, num_filters, 1, 1]
            pooled_outputs.append(pooled.flatten(1))  # [batch_size, num_filters]

        # Concatenate along the channel axis, whatever the number of branches.
        h_pool_flat = torch.cat(pooled_outputs, dim=1)  # [batch_size, num_filters_total]

        logits = torch.mm(h_pool_flat, self.Weight) + self.Bias  # [batch_size, num_classes]
        return logits
model = TextCNN()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training: one optimisation step per epoch on the whole training batch,
# reporting the loss every 1000 epochs.
for epoch in range(5000):
    optimizer.zero_grad()

    # output: [batch_size, num_classes] logits;
    # target_batch: [batch_size] class indices (LongTensor, not one-hot).
    output = model(input_batch)
    loss = criterion(output, target_batch)

    if (epoch + 1) % 1000 == 0:
        print(f"Epoch: {epoch + 1:04d} cost = {loss:.6f}")

    loss.backward()
    optimizer.step()
# Test
test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
# Stack before converting (avoids the slow list-of-ndarray path); Variable
# is unnecessary because plain tensors already support autograd.
test_batch = torch.tensor(np.array(tests), dtype=torch.long)  # [1, num_words]

# Predict: inference only, so no autograd graph is needed. no_grad()
# replaces the legacy `.data` access.
with torch.no_grad():
    predict = model(test_batch).argmax(dim=1, keepdim=True)  # [1, 1] predicted class index

if predict[0][0] == 0:
    print(test_text, "is Bad Mean...")
else:
    print(test_text, "is Good Mean!!")

 

posted @ 2019-11-09 15:13  _Meditation  阅读(1622)  评论(0)  编辑  收藏  举报