PyTorch LSTM的一个简单例子:实现单词词性判断

      本文将使用LSTM来判别一句话中每一个单词的词性。在一句话中,如果我们孤立地看某一个单词,比如单词book,而不看book前面的单词,就不能准确的判断book在这句话中是动词还是名词,但如果我们能记住book前面出现的单词,那么就能很有把握地判断book的词性。LSTM神经网络就能记住前面的单词。关于LSTM的详细介绍,大家可参考文末的参考资料[1][2]。

      下面的代码主要来自文末的参考资料[3],本文对原代码做了修改并增加了注释,使其变得更简单易懂。要理解下面的程序,理解torch.nn.Embedding是关键之一,这篇博客将提供帮助。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
'''
本程序实现了对单词词性的判断,输入一句话,输出该句话中每个单词的词性。
'''
 
import torch
import torch.nn.functional as F
from torch import nn, optim
 
training_data = [("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
                 ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])]
 
word_to_idx = {}
tag_to_idx = {}
for context, tag in training_data:
    for word in context:
        if word not in word_to_idx:
            word_to_idx[word] = len(word_to_idx)
    for label in tag:
        if label not in tag_to_idx:
            tag_to_idx[label] = len(tag_to_idx)
idx_to_tag = {tag_to_idx[tag]: tag for tag in tag_to_idx}
 
 
class LSTMTagger(nn.Module):
    def __init__(self, n_word, n_dim, n_hidden, n_tag):
        super(LSTMTagger, self).__init__()
        self.word_embedding = nn.Embedding(n_word, n_dim)
        self.lstm = nn.LSTM(n_dim, n_hidden, batch_first=True# nn.lstm()接受的数据输入是(序列长度,batch,输入维数),
                                                                # 这和我们cnn输入的方式不太一致,所以使用batch_first=True,把输入变成(batch,序列长度,输入维度),本程序的序列长度指的是一句话的单词数目
                                                                # 同时,batch_first=True会改变输出的维度顺序。<br>
        self.linear1 = nn.Linear(n_hidden, n_tag)
 
    def forward(self, x):            # x是word_list,即单词的索引列表,size为len(x)
        x = self.word_embedding(x)   # embedding之后,x的size为(len(x),n_dim)
        x = x.unsqueeze(0)           # unsqueeze之后,x的size为(1,len(x),n_dim),1在下一行程序的lstm中被当做是batchsize,len(x)被当做序列长度
        x, _ = self.lstm(x)          # lstm的隐藏层输出,x的size为(1,len(x),n_hidden),因为定义lstm网络时用了batch_first=True,所以1在第一维,如果batch_first=False,则len(x)会在第一维
        x = x.squeeze(0)             # squeeze之后,x的size为(len(x),n_hidden),在下一行的linear层中,len(x)被当做是batchsize
        x = self.linear1(x)          # linear层之后,x的size为(len(x),n_tag)
        y = F.log_softmax(x, dim=1# 对第1维先进行softmax计算,然后log一下。y的size为(len(x),n_tag)。
        return y
 
 
model = LSTMTagger(len(word_to_idx), 100, 128, len(tag_to_idx))
if torch.cuda.is_available():
    model = model.cuda()
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2)
 
 
for epoch in range(200):
    running_loss = 0
    for data in training_data:
        sentence, tags = data
        word_list = [word_to_idx[word] for word in sentence]     # word_list是word索引列表
        word_list = torch.LongTensor(word_list)
        tag_list = [tag_to_idx[tag] for tag in tags]             # tag_list是tag索引列表
        tag_list = torch.LongTensor(tag_list)
        if torch.cuda.is_available():
            word_list = word_list.cuda()
            tag_list = tag_list.cuda()
        # forward
        out = model(word_list)
        loss = criterion(out, tag_list)
        running_loss += loss.data.numpy()
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Epoch: {:<3d} | Loss: {:6.4f}'.format(epoch, running_loss / len(data)))
 
# 模型测试
test_sentence = "Everybody ate the apple"
print('\n The test sentence is:\n', test_sentence)
test_sentence = test_sentence.split()
test_list = [word_to_idx[word] for word in test_sentence]
test_list = torch.LongTensor(test_list)
if torch.cuda.is_available():
    test_list = test_list.cuda()
 
out = model(test_list)
_, predict_idx = torch.max(out, 1# 1表示找行的最大值。 predict_idx是词性索引,是一个size为([len(test_sentence)]的张量
predict_tag = [idx_to_tag[idx] for idx in list(predict_idx.numpy())]
print('The predict tags are:', predict_tag)

 

参考资料:

[1] 零基础入门深度学习(6) - 长短时记忆网络(LSTM)

[2] 10分钟快速入门PyTorch (5)

[3] 10分钟快速入门PyTorch (9)

posted @   Picassooo  阅读(2510)  评论(3编辑  收藏  举报
编辑推荐:
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
阅读排行:
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 单元测试从入门到精通
点击右上角即可分享
微信分享提示