Pytorch-基于BiLSTM+CRF实现中文分词
CRF:条件随机场,一种机器学习技术。给定一组输入随机变量条件下,另一组输出随机变量的条件概率分布模型。
以一组词性标注为例,给定输入X={我,喜欢,学习},那么输出为Y={名词,动词,名词}的概率应该为最大。输入序列X又称为观测序列,输出序列Y又称为状态序列。这个状态序列构成马尔可夫随机场,所以根据观测序列,得出状态序列的概率就包括,前一个状态转化为后一状态的概率(即转移概率)和状态变量到观测变量的概率(即发射概率)。
CRF分词原理
1. CRF把分词当做字的词位分类问题,通常定义字的词位信息如下:
- 词首,常用B表示;
- 词中,常用M表示;
- 词尾,常用E表示;
- 单子词,常用S表示;
2. CRF分词的过程就是对词位标注后,将B和E之间的字,以及S单字构成分词;
3. CRF分词实例:
- 原始例句:我爱北京天安门
- CRF标注后:我/S 爱/S 北/B 京/E 天/B 安/M 门/E
- 分词结果:我/爱/北京/天安门
语料截图如下:
由于语料很小,下面程序中创建的映射字典也小,所以预测时不能出现字典外的字,否则报KeyError。
链接:https://pan.baidu.com/s/1SUd-QwlD-WlfqGvo7ElhDw
提取码:v0hx
1.config.py
存放一些超参数。
1 filename='word.txt' 2 EMBEDDING_DIM = 5 3 HIDDEN_DIM = 4 4 epochs=100
2.data_process.py
预处理数据
1 import re 2 import torch 3 START_TAG = "<START>" 4 STOP_TAG = "<STOP>" 5 tag_to_ix = {"B": 0, "M": 1, "E": 2,"S":3, START_TAG: 4, STOP_TAG: 5} 6 7 def prepare_sequence(seq, to_ix): #seq是字序列,to_ix是字和序号的字典 8 idxs = [to_ix[w] for w in seq] #idxs是字序列对应的向量 9 return torch.tensor(idxs, dtype=torch.long) 10 11 #将句子转换为字序列 12 def get_word(sentence): 13 word_list = [] 14 sentence = ''.join(sentence.split(' ')) 15 for i in sentence: 16 word_list.append(i) 17 return word_list 18 19 #将句子转换为BMES序列 20 def get_str(sentence): 21 output_str = [] 22 sentence = re.sub(' ', ' ', sentence) #发现有些句子里面,有两格空格在一起 23 list = sentence.split(' ') 24 for i in range(len(list)): 25 if len(list[i]) == 1: 26 output_str.append('S') 27 elif len(list[i]) == 2: 28 output_str.append('B') 29 output_str.append('E') 30 else: 31 M_num = len(list[i]) - 2 32 output_str.append('B') 33 output_str.extend('M'* M_num) 34 output_str.append('E') 35 return output_str 36 37 def read_file(filename): 38 word, content, label = [], [], [] 39 text = open(filename, 'r', encoding='utf-8') 40 for eachline in text: 41 eachline = eachline.strip('\n') 42 eachline = eachline.strip(' ') 43 word_list = get_word(eachline) 44 letter_list = get_str(eachline) 45 word.extend(word_list) 46 content.append(word_list) 47 label.append(letter_list) 48 return word, content, label #word是单列表,content和label是双层列表
查看下数据内容:
1 text, content, label = read_file('word.txt') 2 print(text) 3 print(content) 4 print(label)

1 ['十', '亿', '中', '华', '儿', '女', '踏', '上', '新', '的', '征', '程', '。', '过', '去', '的', '一', '年', ',', '是', '全', '国', '各', '族', '人', '民', '在', '中', '国', '共', '产', '党', '领', '导', '下', ',', '在', '建', '设', '有', '中', '国', '特', '色', '的', '社', '会', '主', '义', '道', '路', '上', ',', '坚', '持', '改', '革', '、', '开', '放', ',', '团', '结', '奋', '斗', '、', '胜', '利', '前', '进', '的', '一', '年', '。', '城', '乡', '经', '济', '体', '制', '改', '革', '向', '纵', '深', '稳', '步', '发', '展', ',', '对', '外', '开', '放', '迈', '出', '了', '新', '的', '步', '伐', ',', '工', '农', '业', '生', '产', '和', '其', '它', '各', '项', '建', '设', '事', '业', '全', '面', '完', '成', '了', '“', '七', '五', '”', '计', '划', '第', '一', '年', '的', '任', '务', ',', '人', '民', '生', '活', '继', '续', '有', '所', '改', '善', '。', '政', '治', '上', '安', '定', '团', '结', ',', '端', '正', '党', '风', '和', '社', '会', '风', '气', '的', '工', '作', '取', '得', '了', '新', '的', '进', '展', ',', '社', '会', '主', '义', '民', '主', '和', '法', '制', '建', '设', '不', '断', '加', '强', '。', '在', '党', '的', '十', '二', '届', '六', '中', '全', '会', '通', '过', '的', '《', '关', '于', '社', '会', '主', '义', '精', '神', '文', '明', '建', '设', '指', '导', '方', '针', '的', '决', '议', '》', '指', '引', '下', ',', '我', '国', '两', '个', '文', '明', '的', '建', '设', '正', '在', '向', '新', '的', '水', '平', '迈', '步', '。', '从', '党', '的', '十', '一', '届', '三', '中', '全', '会', '实', '现', '伟', '大', '历', '史', '转', '折', '到', '现', '在', ',', '我', '国', '政', '治', '安', '定', '团', '结', ',', '经', '济', '稳', '定', '、', '持', '续', '、', '协', '调', '发', '展', '已', '经', '八', '年', '了', ',', '这', '是', '建', '国', '以', '来', '稳', '步', '发', '展', '持', '续', '时', '间', '最', '长', '的', '时', '期', '。', '在', '十', '年', '动', '乱', '之', '后', ',', '取', '得', '这', '样', '一', '个', '大', '好', '局', '面', '是', '不', '容', '易', '的', '。'] 2 [['十', '亿', '中', '华', '儿', '女', '踏', '上', '新', '的', '征', '程', '。'], ['过', '去', '的', '一', '年', ',', '是', '全', '国', '各', '族', '人', '民', '在', '中', '国', '共', '产', '党', '领', '导', '下', ','], ['在', '建', '设', '有', '中', '国', '特', '色', '的', '社', '会', '主', '义', '道', '路', '上', ',', '坚', '持', '改', '革', '、', '开', '放', ',', '团', '结', '奋', '斗', '、', '胜', '利', '前', '进', '的', '一', '年', '。'], ['城', '乡', '经', '济', '体', '制', '改', '革', '向', '纵', '深', '稳', '步', '发', '展', ',', '对', '外', '开', '放', '迈', '出', '了', '新', '的', '步', '伐', ',', '工', '农', '业', '生', '产', '和', '其', '它', '各', '项', '建', '设', '事', '业', '全', '面', '完', '成', '了', '“', '七', '五', '”', '计', '划', '第', '一', '年', '的', '任', '务', ',', '人', '民', '生', '活', '继', '续', '有', '所', '改', '善', '。'], ['政', '治', '上', '安', '定', '团', '结', ',', '端', '正', '党', '风', '和', '社', '会', '风', '气', '的', '工', '作', '取', '得', '了', '新', '的', '进', '展', ',', '社', '会', '主', '义', '民', '主', '和', '法', '制', '建', '设', '不', '断', '加', '强', '。'], ['在', '党', '的', '十', '二', '届', '六', '中', '全', '会', '通', '过', '的', '《', '关', '于', '社', '会', '主', '义', '精', '神', '文', '明', '建', '设', '指', '导', '方', '针', '的', '决', '议', '》', '指', '引', '下', ',', '我', '国', '两', '个', '文', '明', '的', '建', '设', '正', '在', '向', '新', '的', '水', '平', '迈', '步', '。'], ['从', '党', '的', '十', '一', '届', '三', '中', '全', '会', '实', '现', '伟', '大', '历', '史', '转', '折', '到', '现', '在', ',', '我', '国', '政', '治', '安', '定', '团', '结', ',', '经', '济', '稳', '定', '、', '持', '续', '、', '协', '调', '发', '展', '已', '经', '八', '年', '了', ',', '这', '是', '建', '国', '以', '来', '稳', '步', '发', '展', '持', '续', '时', '间', '最', '长', '的', '时', '期', '。'], ['在', '十', '年', '动', '乱', '之', '后', ',', '取', '得', '这', '样', '一', '个', '大', '好', '局', '面', '是', '不', '容', '易', '的', '。']] 3 [['B', 'E', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'S', 'B', 'E', 'S'], ['B', 'E', 'S', 'B', 'E', 'S', 'S', 'B', 'E', 'B', 'M', 'M', 'E', 'S', 'B', 'M', 'M', 'M', 'E', 'B', 'E', 'S', 'S'], ['S', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'S', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'S', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'S', 'B', 'M', 'M', 'E', 'S', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'S'], ['B', 'E', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'M', 'M', 'E', 'S', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'S', 'S', 'B', 'E', 'S', 'B', 'M', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'M', 'M', 'E', 'S', 'S', 'B', 'E', 'S', 'B', 'E', 'B', 'M', 'E', 'S', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'M', 'M', 'E', 'S'], ['B', 'E', 'S', 'B', 'M', 'M', 'E', 'S', 'B', 'M', 'M', 'E', 'S', 'B', 'M', 'M', 'E', 'S', 'B', 'E', 'B', 'E', 'S', 'S', 'S', 'B', 'E', 'S', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'B', 'M', 'M', 'E', 'B', 'M', 'M', 'E', 'S'], ['S', 'S', 'S', 'B', 'M', 'E', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'S', 'B', 'E', 'B', 'M', 'M', 'E', 'B', 'M', 'M', 'E', 'B', 'E', 'B', 'M', 'M', 'E', 'S', 'B', 'E', 'S', 'B', 'E', 'S', 'S', 'B', 'E', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'S', 'S', 'S', 'B', 'E', 'B', 'E', 'S'], ['B', 'E', 'S', 'B', 'M', 'M', 'M', 'M', 'M', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'M', 'M', 'E', 'S', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'S', 'S', 'B', 'E', 'B', 'M', 'M', 'E', 'B', 'M', 'M', 'E', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'B', 'E', 'S'], ['S', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'M', 'M', 'E', 'S', 'S', 'B', 'E', 'S', 'S']]
3.BiLSTM_CRF.py
关于BiLSTM+CRF的详细理解:https://zhuanlan.zhihu.com/p/97676647
转移概率矩阵transitions,transitionsij表示t时刻隐状态为qi,t+1时刻隐状态转换为qj的概率,即P(it+1=qj|it=qi)
1 import torch 2 from data_process import START_TAG,STOP_TAG 3 from torch import nn 4 5 def argmax(vec): #返回每一行最大值的索引 6 _, idx = torch.max(vec, 1) 7 return idx.item() 8 9 10 def prepare_sequence(seq, to_ix): #seq是字序列,to_ix是字和序号的字典 11 idxs = [to_ix[w] for w in seq] #idxs是字序列对应的向量 12 return torch.tensor(idxs, dtype=torch.long) 13 14 15 #LSE函数,模型中经常用到的一种路径运算的实现 16 def log_sum_exp(vec): #vec.shape=[1, target_size] 17 max_score = vec[0, argmax(vec)] #每一行的最大值 18 max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1]) 19 return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast))) 20 21 22 class BiLSTM_CRF(nn.Module): 23 24 def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim): 25 super(BiLSTM_CRF, self).__init__() 26 self.embedding_dim = embedding_dim 27 self.hidden_dim = hidden_dim 28 self.vocab_size = vocab_size 29 self.tag_to_ix = tag_to_ix 30 self.tagset_size = len(tag_to_ix) 31 32 self.word_embeds = nn.Embedding(vocab_size, embedding_dim) 33 self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True) 34 35 self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size) # Maps the output of the LSTM into tag space 36 37 self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size)) #随机初始化转移矩阵 38 39 self.transitions.data[tag_to_ix[START_TAG], :] = -10000 #tag_to_ix[START_TAG]: 3(第三行,即其他状态到START_TAG的概率) 40 self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000 #tag_to_ix[STOP_TAG]: 4(第四列,即STOP_TAG到其他状态的概率) 41 self.hidden = self.init_hidden() 42 43 def init_hidden(self): 44 return (torch.randn(2, 1, self.hidden_dim // 2), torch.randn(2, 1, self.hidden_dim // 2)) 45 46 #所有路径的得分,CRF的分母 47 def _forward_alg(self, feats): 48 init_alphas = torch.full((1, self.tagset_size), -10000.) #初始隐状态概率,第1个字是O1的实体标记是qi的概率 49 init_alphas[0][self.tag_to_ix[START_TAG]] = 0. 50 51 forward_var = init_alphas #初始状态的forward_var,随着step t变化 52 53 for feat in feats: #feat的维度是[1, target_size] 54 alphas_t = [] 55 for next_tag in range(self.tagset_size): #给定每一帧的发射分值,按照当前的CRF层参数算出所有可能序列的分值和 56 57 emit_score = feat[next_tag].view(1, -1).expand(1, self.tagset_size) #发射概率[1, target_size] 隐状态到观测状态的概率 58 trans_score = self.transitions[next_tag].view(1, -1) #转移概率[1, target_size] 隐状态到下一个隐状态的概率 59 next_tag_var = forward_var + trans_score + emit_score #本身应该相乘求解的,因为用log计算,所以改为相加 60 61 alphas_t.append(log_sum_exp(next_tag_var).view(1)) 62 63 forward_var = torch.cat(alphas_t).view(1, -1) 64 65 terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]] #最后转到[STOP_TAG],发射分值为0,转移分值为列向量 self.transitions[:, [self.tag2ix[END_TAG]]] 66 return log_sum_exp(terminal_var) 67 68 #得到feats,维度=len(sentence)*tagset_size,表示句子中每个词是分别为target_size个tag的概率 69 def _get_lstm_features(self, sentence): 70 self.hidden = self.init_hidden() 71 embeds = self.word_embeds(sentence).view(len(sentence), 1, -1) 72 lstm_out, self.hidden = self.lstm(embeds, self.hidden) 73 lstm_out = lstm_out.view(len(sentence), self.hidden_dim) 74 lstm_feats = self.hidden2tag(lstm_out) 75 return lstm_feats 76 77 #正确路径的分数,CRF的分子 78 def _score_sentence(self, feats, tags): 79 score = torch.zeros(1) 80 tags = torch.cat([torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long), tags]) 81 for i, feat in enumerate(feats): 82 #self.transitions[tags[i + 1], tags[i]] 是从标签i到标签i+1的转移概率 83 #feat[tags[i+1]], feat是step i的输出结果,有5个值,对应B, I, E, START_TAG, END_TAG, 取对应标签的值 84 score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]] # 沿途累加每一帧的转移和发射 85 score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]] # 加上到END_TAG的转移 86 return score 87 88 89 #解码,得到预测序列的得分,以及预测的序列 90 def _viterbi_decode(self, feats): 91 backpointers = [] #回溯路径;backpointers[i][j]=第i帧到达j状态的所有路径中, 得分最高的那条在i-1帧是什么状态 92 93 # Initialize the viterbi variables in log space 94 init_vvars = torch.full((1, self.tagset_size), -10000.) 95 init_vvars[0][self.tag_to_ix[START_TAG]] = 0 96 97 forward_var = init_vvars 98 for feat in feats: 99 bptrs_t = [] 100 viterbivars_t = [] 101 102 for next_tag in range(self.tagset_size): 103 104 next_tag_var = forward_var + self.transitions[next_tag] #其他标签(B,I,E,Start,End)到标签next_tag的概率 105 best_tag_id = argmax(next_tag_var) #选择概率最大的一条的序号 106 bptrs_t.append(best_tag_id) 107 viterbivars_t.append(next_tag_var[0][best_tag_id].view(1)) 108 109 forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1) #从step0到step(i-1)时5个序列中每个序列的最大score 110 backpointers.append(bptrs_t) 111 112 113 terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]] #其他标签到STOP_TAG的转移概率 114 best_tag_id = argmax(terminal_var) 115 path_score = terminal_var[0][best_tag_id] 116 117 best_path = [best_tag_id] 118 for bptrs_t in reversed(backpointers): #从后向前走,找到一个best路径 119 best_tag_id = bptrs_t[best_tag_id] 120 best_path.append(best_tag_id) 121 122 start = best_path.pop() 123 assert start == self.tag_to_ix[START_TAG] #安全性检查 124 best_path.reverse() #把从后向前的路径倒置 125 return path_score, best_path 126 127 #求负对数似然,作为loss 128 def neg_log_likelihood(self, sentence, tags): 129 feats = self._get_lstm_features(sentence) #emission score 130 forward_score = self._forward_alg(feats) #所有路径的分数和,即b 131 gold_score = self._score_sentence(feats, tags) #正确路径的分数,即a 132 return forward_score - gold_score #注意取负号 -log(a/b) = -[log(a) - log(b)] = log(b) - log(a) 133 134 135 def forward(self, sentence): 136 lstm_feats = self._get_lstm_features(sentence) 137 score, tag_seq = self._viterbi_decode(lstm_feats) 138 return score, tag_seq
4.training.py
1 from data_process import read_file, tag_to_ix 2 from config import * 3 from BiLSTM_CRF import * 4 import torch 5 from torch import nn 6 from torch import optim 7 8 _, content, label = read_file(filename) 9 10 def train_data(content, label): 11 train_data = [] 12 for i in range(len(label)): 13 train_data.append((content[i], label[i])) 14 return train_data 15 data = train_data(content,label) 16 17 word_to_ix = {} 18 for sentence, tags in data: 19 for word in sentence: 20 if word not in word_to_ix: 21 word_to_ix[word] = len(word_to_ix) #单词映射,字到序号 22 23 24 model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM) 25 optimizer = optim.Adam(model.parameters(), lr=1e-3) 26 27 #训练 28 for epoch in range(epochs): 29 for sentence, tags in data: 30 model.zero_grad() 31 32 sentence_in = prepare_sequence(sentence, word_to_ix) 33 targets = torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long) 34 loss = model.neg_log_likelihood(sentence_in, targets) 35 36 loss.backward() 37 optimizer.step() 38 if epoch%10 == 0: 39 print('epoch/epochs: {}/{}, loss:{:.6f}'.format(epoch+1, epochs, loss.data[0])) 40 41 #保存模型 42 torch.save(model,'cws.model') 43 torch.save(model.state_dict(),'cws_all.model')
5.test_model.py
调用上面保存的模型,进行预测。
1 from trainning import word_to_ix 2 from data_process import prepare_sequence 3 import torch 4 5 net = torch.load('cws.model') 6 net.eval() 7 stri="改善人民生活水平,建设社会主义政治经济。" 8 precheck_sent = prepare_sequence(stri, word_to_ix) 9 #precheck_sent= tensor([ 45, 102, 23, 24, 80, 98, 140, 141, 17, 32, 33, 37, 38, 39, 40, 103, 104, 60, 61, 12]) 10 11 label = net(precheck_sent)[1] 12 #net(precheck_sent)= (tensor(32.3123, grad_fn=<SelectBackward>), [0, 2, 0, 2, 0, 2, 0, 2, 3, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 2]) 13 14 cws=[] 15 for i in range(len(label)): 16 cws.extend(stri[i]) 17 if label[i]==2 or label==3: 18 cws.append('/') 19 #cws= ['改', '善', '/', '人', '民', '/', '生', '活', '/', '水', '平', '/', ',', '建', '设', '/', '社', '会', '/', '主', '义', '/', '政', '治', '/', '经', '济', '/', '。'] 20 21 print('输入未分词语句:', stri) 22 print('分词结果:', ''.join(cws))
1 epoch/epochs: 1/100, loss:33.839325 2 epoch/epochs: 11/100, loss:31.749798 3 epoch/epochs: 21/100, loss:29.822870 4 epoch/epochs: 31/100, loss:27.391972 5 epoch/epochs: 41/100, loss:26.033567 6 epoch/epochs: 51/100, loss:24.467463 7 epoch/epochs: 61/100, loss:22.403660 8 epoch/epochs: 71/100, loss:20.725002 9 epoch/epochs: 81/100, loss:18.280849 10 epoch/epochs: 91/100, loss:16.049187
输入未分词语句: 改善人民生活水平,建设社会主义政治经济。
分词结果: 改善/人民/生活/水平/,建设/社会/主义/政治/经济/。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 【译】Visual Studio 中新的强大生产力特性
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构