开发记录3
上次完成了关键字的提取,这一次就实现自动分类
在实现自动分类的时候,我在晚上找了很多关于自动分类的方法,找了关于spark,关于python的,java的等等都比较乱
然后我又在网上找了基于python的机器学习,可以自动对内容进行自动分类,代码如下:
#!/usr/bin/env python # coding=utf-8 import sys import jieba from sklearn.pipeline import Pipeline from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer from sklearn.svm import LinearSVC from sklearn.multiclass import OneVsRestClassifier from sklearn.preprocessing import MultiLabelBinarizer import pymysql import pandas as pd import re import numpy as np def jieba_tokenizer(x): return jieba.cut(x, cut_all=True) def partition(x): return x def filter_html(s): d = re.compile(r'<[^>]+>', re.S) s = d.sub('', s) return s def gbk_utf8(s): s = s.decode('gbk', "ignore").encode('utf8') return s def write_sql(id,classs): db = pymysql.Connection(host="localhost", port=3306, user="root", password="root", database="dazuoye", charset="utf8") cursor = db.cursor() sql = "update info_tech set type='" + classs + "' where index=" + str(id) try: cursor.execute(sql) db.commit() except: db.commit() print("出错了!") db.close() # 链接mysql数据库 conn = pymysql.Connection(host="localhost",port=3306,user="root", password="root",database="dazuoye",charset="utf8") cursor = conn.cursor() cursor=conn.cursor() # 训练数据样本 data_ret = pd.DataFrame() sql = "SELECT index, title3,type,content FROM info_tech " # print sql cursor.execute(sql) txt_ret = [] #class_ret = [["信息化"],["大数据"],["云计算"],["区块链"],["智慧城市"],["工业互联网"],["信息安全"],["操作系统"],["计算机"],["法律法规"],["信息化战略"]] class_ret=[] id_ret = [] for row in cursor.fetchall(): content = filter_html(row[3]) txt_ret.append(content) class_s = row[2] class_l = class_s.split(" ") class_ret.append(class_l) id_ret.append(row[0]) txt_ret = txt_ret X_train = txt_ret print(class_ret) Y_train = class_ret classifier = Pipeline([ ('counter', CountVectorizer(tokenizer=jieba_tokenizer)), ('tfidf', TfidfTransformer()), ('clf', OneVsRestClassifier(LinearSVC())), ]) mlb = MultiLabelBinarizer() Y_train = mlb.fit_transform(Y_train) classifier.fit(X_train, Y_train) print(classifier.score(X_train,Y_train)) # 测试数据 test_txt_set = [] sql = "SELECT index, title3,keyword,content FROM info_tech " cursor.execute(sql) test_id_ret = [] for row in cursor.fetchall(): test_txt_set.append(filter_html(row[3])) test_id_ret.append(row[0]) X_test = test_txt_set prediction = classifier.predict(X_test) result = mlb.inverse_transform(prediction) # 展示结果 for i, label1 in enumerate(result): classstr = '' for j, label2 in enumerate(label1): classstr += str(label2) + "" print("ID:" + str(test_id_ret[i]) + " =>class:" + classstr) write_sql(test_id_ret[i],classstr)
参考教程:https://morvanzhou.github.io/tutorials/machine-learning/sklearn/
曾请教:王莉