基于keras采用LSTM实现多标签文本分类
首先抓取博客园知识库(kb.cnblogs.com)中的文章标题及其分类,每篇文章以“[分类] 标题”的格式逐行追加到 data/bokeyuan_fenlei.csv 中。
代码:
#coding=utf-8
"""Scrape article titles and their categories from the cnblogs knowledge base.

Each article is appended to `filepath` as one "[category] title" line.
"""
import os
import sys
import requests
from lxml import etree,html
import lxml
import time
import re

filepath = 'data/bokeyuan_fenlei.csv'
# Characters kept in titles/categories: ASCII letters/digits, CJK ideographs
# and the punctuation :、?!, ; everything else is dropped.
_KEEP_CHARS = re.compile('[0-9a-zA-Z\u4e00-\u9fa5:、?!,]')


def zhuaqudata():
    """Crawl listing pages starting at page 1 until no "next" link remains.

    Every article found is printed and appended to `filepath`. Because the
    file is opened in append mode, running the crawler twice duplicates data.
    """
    page = 1
    while True:
        print("开始抓取%s页..." % page)
        haslast, titles, fenleis = getwenzhangandnext(page)
        lines = []
        for fenlei, title in zip(fenleis, titles):
            print('[%s] %s' % (fenlei, title))
            lines.append("[%s] %s\n" % (fenlei, title))
        if lines:
            # One append per page instead of reopening the file per title.
            writefile(filepath, ''.join(lines))
        print()
        if not haslast:
            break
        page += 1


def getwenzhangandnext(page):
    """Fetch one listing page and extract its articles.

    Args:
        page: 1-based page number; page 1 is the site root.

    Returns:
        (haslast, titles, fenleis): haslast is 1 when the pager's last link
        text contains "next" (another page exists), else 0; titles and
        fenleis are parallel lists of cleaned strings.
    """
    baseurl = 'https://kb.cnblogs.com/'
    url = baseurl if page == 1 else baseurl + str(page) + '/'
    print(url)
    htmlcontent = etree.HTML(geturl(url))
    titles = []
    fenleis = []
    for p in htmlcontent.xpath('//div[@class="list_block"]//div[@class="msg_title"]//p'):
        # Query the elements directly rather than searching the serialized
        # HTML for the substring 'span', which also matches titles that
        # contain that word and then fails on the missing element.
        title_attrs = p.xpath('.//a//@title')
        span_texts = p.xpath('.//span//text()')
        if not title_attrs or not span_texts:
            continue
        titles.append(formatstr(title_attrs[0]))
        fenleis.append(formatstr(span_texts[0]))
    pager_texts = htmlcontent.xpath(
        '//div[@id="pager_block"]//div[@id="pager"]//a[last()]//text()')
    # A page without a pager is treated as the last page instead of raising.
    haslasttext = str(pager_texts[0]) if pager_texts else ''
    haslast = 1 if 'next' in haslasttext.lower() else 0
    time.sleep(3)  # pause between page requests to go easy on the server
    return haslast, titles, fenleis


def formatstr(str):
    """Return the given text with every character outside _KEEP_CHARS removed."""
    return ''.join(_KEEP_CHARS.findall(str))


def readfile(filepath):
    """Return the whole content of the UTF-8 text file at *filepath*."""
    with open(filepath, 'r', encoding='utf-8') as fp:
        return fp.read()


def writefile(filepath, s):
    """Append *s* to the UTF-8 text file at *filepath*, creating its directory if needed."""
    dirname = os.path.dirname(filepath)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(filepath, 'a', encoding='utf-8') as fp:
        fp.write(s)


def geturl(url, timeout=30):
    """GET *url* with a desktop-browser User-Agent and return the decoded body.

    Raises:
        requests.HTTPError: for 4xx/5xx responses.
        requests.Timeout: if connecting or reading stalls longer than
            *timeout* seconds.
    """
    header = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:95.0) Gecko/20100101 Firefox/95.0'
    }
    res = requests.get(url, headers=header, timeout=timeout)
    res.raise_for_status()
    # Decode using the charset detected from the body, not the HTTP headers.
    res.encoding = res.apparent_encoding
    return res.text


if __name__ == '__main__':
    zhuaqudata()
结果:
然后用程序读取该文件,建立标题与分类标签的对应关系,再依次完成编码、建模、训练和测试。
代码:
#coding=utf-8
"""Train and query an LSTM classifier for cnblogs article titles.

Reads "[category] title" lines from `filepath`, segments the titles with
jieba, vectorises them with a Keras Tokenizer and trains (or loads) an
Embedding -> LSTM model with one sigmoid output per category.
"""
import os
import sys
import re
import jieba
from sklearn.preprocessing import MultiLabelBinarizer
from keras.preprocessing.text import Tokenizer
from keras_preprocessing.sequence import pad_sequences
from keras.models import Sequential,Model,load_model
import numpy as np
from keras.layers import Dense, Input, Flatten, Dropout, LSTM
from keras.layers import Conv1D, MaxPooling1D, Embedding, GlobalMaxPooling1D, SpatialDropout1D
import random

filepath = 'data/bokeyuan_fenlei.csv'
stopwordfilepath = 'data/cn_stopwords.txt'

# Vocabulary size kept by the Tokenizer; must equal the Embedding input_dim.
MAX_WORDS = 40000
# Every title is padded/truncated to this many token ids.
MAX_LEN = 100
# One "[category] title" record per line.
_RECORD_RE = re.compile(r'\[(.*?)\](.*?)\n')
# Characters kept before segmentation: ASCII letters/digits and CJK ideographs.
_KEEP_CHARS = re.compile('[0-9a-zA-Z\u4e00-\u9fa5]')


def readfile(filepath):
    """Return the whole content of the UTF-8 text file at *filepath*."""
    with open(filepath, 'r', encoding='utf-8') as fp:
        return fp.read()


def writefile(filepath, s):
    """Append *s* to the UTF-8 text file at *filepath*, creating its directory if needed."""
    dirname = os.path.dirname(filepath)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(filepath, 'a', encoding='utf-8') as fp:
        fp.write(s)


def _load_stopwords():
    """Return the stop words from `stopwordfilepath` as a set (one per line, blanks skipped)."""
    return {w.strip() for w in readfile(stopwordfilepath).split('\n') if w.strip()}


def duqushuju(train_ratio=0):
    """Load the labelled titles and split them into train/test sets.

    Args:
        train_ratio: fraction of records, in file order, used for training.
            With the default 0 no split is made and the test set is a copy
            of the training set.

    Returns:
        (all_data, all_fenlei, train_data, train_label, test_data, test_label):
        *_data are lists of space-joined token strings; all_fenlei and
        *_label are lists of one-element label lists.
    """
    stopwords = _load_stopwords()
    titles = []
    fenleis = []
    for fenlei, title in _RECORD_RE.findall(readfile(filepath)):
        fenleis.append([fenlei])
        titles.append(contentsplit(title, stopwords))
    trainlen = int(len(fenleis) * train_ratio)
    if trainlen > 0:
        train_data, train_label = titles[:trainlen], fenleis[:trainlen]
        test_data, test_label = titles[trainlen:], fenleis[trainlen:]
    else:
        # No hold-out set: "validation" metrics are computed on training data.
        train_data, train_label = titles[:], fenleis[:]
        test_data, test_label = titles[:], fenleis[:]
    return titles, fenleis, train_data, train_label, test_data, test_label


def contentsplit(segment, stopwords):
    """Clean *segment*, cut it with jieba and return the kept words joined by spaces.

    Words that are blank, single-character, or (after stripping) contained
    in *stopwords* are dropped. *stopwords* may be any container supporting
    ``in``; pass a set for O(1) lookups.
    """
    words = jieba.cut(formatstr(segment))
    kept = [w for w in words if w.strip() and w.strip() not in stopwords and len(w) > 1]
    return " ".join(kept)


def formatstr(str):
    """Return the given text with everything but ASCII letters/digits and CJK ideographs removed."""
    return ''.join(_KEEP_CHARS.findall(str))


def _build_model(num_classes):
    """Build and compile the Embedding -> LSTM -> sigmoid multi-label classifier."""
    inputs = Input(shape=(MAX_LEN,))
    embed = Embedding(MAX_WORDS, 100, input_length=MAX_LEN)(inputs)
    dropout = SpatialDropout1D(0.2)(embed)
    # tanh/sigmoid activations with recurrent_dropout=0 keep the layer
    # eligible for the cuDNN kernel on GPU.
    lstm = LSTM(100, dropout=0.2, recurrent_dropout=0,
                activation='tanh', recurrent_activation='sigmoid')(dropout)
    output = Dense(num_classes, activation='sigmoid')(lstm)
    model = Model(inputs, output)
    # With binary_crossentropy Keras resolves 'accuracy' to binary accuracy
    # averaged over all label columns, which reads high with many classes.
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model


def _predict_top_label(model, tokenizer, binarizer, texts):
    """Return the highest-scoring category name for each pre-segmented text."""
    if not texts:
        return []
    seqs = pad_sequences(tokenizer.texts_to_sequences(texts), MAX_LEN)
    probs = model.predict(seqs, batch_size=len(texts))
    # argmax is a column index into binarizer.classes_, not a sample index.
    return [binarizer.classes_[row.argmax()] for row in probs]


def main():
    """Load data, train or load the model, then print sample predictions."""
    all_data, all_fenlei, train_data, train_label, test_data, test_label = duqushuju()
    print('总分类大小:%s' % len(all_fenlei))
    print('总标题大小:%s' % len(all_data))
    print('训练分类大小:%s' % len(train_label))
    print('训练标题大小:%s' % len(train_data))
    print('测试分类大小:%s' % len(test_label))
    print('测试标题大小:%s' % len(test_data))

    # Fit the label binarizer once, so train and test matrices share one
    # column order; keep it to map predicted columns back to category names.
    mutil_lab = MultiLabelBinarizer()
    mutil_lab.fit(all_fenlei)
    y_data = np.array(mutil_lab.transform(train_label))
    y_test_data = np.array(mutil_lab.transform(test_label))

    tokenizer = Tokenizer(num_words=MAX_WORDS, filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~', lower=True)
    tokenizer.fit_on_texts(train_data)
    x_data = pad_sequences(tokenizer.texts_to_sequences(train_data), MAX_LEN)
    x_test_data = pad_sequences(tokenizer.texts_to_sequences(test_data), MAX_LEN)
    print("训练集的大小为: ", x_data.shape, "训练集标签的大小为: ", y_data.shape)
    print("测试集的大小为: ", x_test_data.shape, "测试集标签的大小为: ", y_test_data.shape)

    model_path = 'models/wenben_fenlei_lstm.h5'
    if os.path.exists(model_path):
        # NOTE(review): the tokenizer and binarizer are rebuilt from the data
        # file on every run; if that file changed after the model was saved,
        # token ids / label columns may no longer match the saved weights.
        model = load_model(model_path)
    else:
        model = _build_model(y_data.shape[1])
        model.summary()
        model.fit(x_data, y_data, batch_size=16, epochs=20, validation_data=(x_test_data, y_test_data))
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        model.save(model_path)

    n = 3
    preds = _predict_top_label(model, tokenizer, mutil_lab, train_data[:n])
    for labels, text, pred in zip(train_label, train_data, preds):
        print('[%s] %s' % (','.join(labels), text))
        print('预测值为:%s' % pred)
        print()

    ceshi_data = ['FWT/快速沃尔什变换 入门指南', '如何在 Apinto 实现 HTTP 与gRPC 的协议转换 (下)', '万字血书Vue—Vue语法', '云图说丨初识华为云安全云脑——新一代云安全运营中心']
    # New titles need the same cleaning + jieba segmentation as the training
    # titles; raw strings are only split on spaces/punctuation and mostly
    # miss the jieba-built vocabulary.
    stopwords = _load_stopwords()
    ceshi_tokens = [contentsplit(title, stopwords) for title in ceshi_data]
    preds = _predict_top_label(model, tokenizer, mutil_lab, ceshi_tokens)
    for title, pred in zip(ceshi_data, preds):
        print('%s' % title)
        print('预测值为:%s' % pred)
        print()


if __name__ == '__main__':
    main()
停用词文件 data/cn_stopwords.txt 可以自行创建,即使是空文件程序也能正常运行。停用词只用于过滤 jieba 切出的词语,因此会影响进入模型的词,但不会改变切词本身。
我先对训练集中的前三个标题做了预测,结果基本正确;随后又对 4 个博客文章标题做了预测,至少能够得到结果。需要注意的是,代码默认没有划分训练集和测试集(测试集与训练集相同),因此这里的效果只能说明模型拟合了训练数据,不能代表它对新标题的泛化能力。
效果:
参考:https://blog.csdn.net/qq_56154355/article/details/125685955
本文来自博客园,作者:河北大学-徐小波,转载请注明原文链接:https://www.cnblogs.com/xuxiaobo/p/17227240.html