

  1. 概述
  2. 数据集
  3. 代码
  4. 结果展示





训练数据下载地址(百度网盘): https://pan.baidu.com/s/1ZHh98RrjQpG5Tm-yq73vBQ  提取码: 2r04




3.1 数据加载与预处理 (cnews_loader.py)


    1     # coding: utf-8
    2     import sys
    3     from collections import Counter
    4     import numpy as np
    5     import tensorflow.contrib.keras as kr
    7     if sys.version_info[0] > 2:
    8         is_py3 = True
    9     else:
   10         reload(sys)
   11         sys.setdefaultencoding("utf-8")
   12         is_py3 = False
   14     def native_word(word, encoding='utf-8'):
   15         """如果在python2下面使用python3训练的模型,可考虑调用此函数转化一下字符编码"""
   16         if not is_py3:
   17             return word.encode(encoding)
   18         else:
   19             return word
   21     def native_content(content):
   22         if not is_py3:
   23             return content.decode('utf-8')
   24         else:
   25             return content
   27     def open_file(filename, mode='r'):
   28         """
   29         常用文件操作,可在python2和python3间切换.
   30         mode: 'r' or 'w' for read or write
   31         """
   32         if is_py3:
   33             return open(filename, mode, encoding='utf-8', errors='ignore')
   34         else:
   35             return open(filename, mode)
   37     def read_file(filename):
   38         """读取文件数据"""
   39         contents, labels = [], []
   40         with open_file(filename) as f:
   41             for line in f:
   42                 try:
   43                     label, content = line.strip().split('\t')
   44                     if content:
   45                         contents.append(list(native_content(content)))
   46                         labels.append(native_content(label))
   47                 except:
   48                     pass
   49         return contents, labels
   51     def build_vocab(train_dir, vocab_dir, vocab_size=5000):
   52         """根据训练集构建词汇表,存储"""
   53         data_train, _ = read_file(train_dir)
   54         all_data = []
   55         for content in data_train:
   56             all_data.extend(content)
   57         counter = Counter(all_data)
   58         count_pairs = counter.most_common(vocab_size - 1)
   59         words, _ = list(zip(*count_pairs))
   60         # 添加一个 <PAD> 来将所有文本pad为同一长度
   61         words = ['<PAD>'] + list(words)
   62         open_file(vocab_dir, mode='w').write('\n'.join(words) + '\n')
   64     def read_vocab(vocab_dir):
   65         """读取词汇表"""
   66         # words = open_file(vocab_dir).read().strip().split('\n')
   67         with open_file(vocab_dir) as fp:
   68             # 如果是py2 则每个值都转化为unicode
   69             words = [native_content(_.strip()) for _ in fp.readlines()]
   70         word_to_id = dict(zip(words, range(len(words))))
   71         return words, word_to_id
   73     def read_category():
   74         """读取分类目录,固定"""
   75         categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
   76         categories = [native_content(x) for x in categories]
   77         cat_to_id = dict(zip(categories, range(len(categories))))
   78         return categories, cat_to_id
   80     def to_words(content, words):
   81         """将id表示的内容转换为文字"""
   82         return ''.join(words[x] for x in content)
   84     def process_file(filename, word_to_id, cat_to_id, max_length=600):
   85         """将文件转换为id表示"""
   86         contents, labels = read_file(filename)
   88         data_id, label_id = [], []
   89         for i in range(len(contents)):
   90             data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
   91             label_id.append(cat_to_id[labels[i]])
   92         # 使用keras提供的pad_sequences来将文本pad为固定长度
   93         x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
   94         y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # 将标签转换为one-hot表示
   95         return x_pad, y_pad
   97     def batch_iter(x, y, batch_size=64):
   98         """生成批次数据"""
   99         data_len = len(x)
  100         num_batch = int((data_len - 1) / batch_size) + 1
  101         indices = np.random.permutation(np.arange(data_len))
  102         x_shuffle = x[indices]
  103         y_shuffle = y[indices]
  105         for i in range(num_batch):
  106             start_id = i * batch_size
  107             end_id = min((i + 1) * batch_size, data_len)
  108             yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]


3.2 模型搭建cnn_model.py


    1     # coding: utf-8
    2     import tensorflow as tf
    3     class TCNNConfig(object):
    4         """CNN配置参数"""
    5         embedding_dim = 64  # 词向量维度
    6         seq_length = 600  # 序列长度
    7         num_classes = 10  # 类别数
    8         num_filters = 256  # 卷积核数目
    9         kernel_size = 5  # 卷积核尺寸
   10         vocab_size = 5000  # 词汇表达小
   11         hidden_dim = 128  # 全连接层神经元
   12         dropout_keep_prob = 0.5  # dropout保留比例
   13         learning_rate = 1e-3  # 学习率
   14         batch_size = 64  # 每批训练大小
   15         num_epochs = 10  # 总迭代轮次
   16         print_per_batch = 100  # 每多少轮输出一次结果
   17         save_per_batch = 10  # 每多少轮存入tensorboard
   19     class TextCNN(object):
   20         """文本分类,CNN模型"""
   21         def __init__(self, config):
   22             self.config = config
   23             # 三个待输入的数据
   24             self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
   25             self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
   26             self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
   27             self.cnn()
   29         def cnn(self):
   30             """CNN模型"""
   31             # 词向量映射
   32             with tf.device('/cpu:0'):
   33                 embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
   34                 embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)
   36             with tf.name_scope("cnn"):
   37                 # CNN layer
   38                 conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, self.config.kernel_size, name='conv')
   39                 # global max pooling layer
   40                 gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')
   42             with tf.name_scope("score"):
   43                 # 全连接层,后面接dropout以及relu激活
   44                 fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
   45                 fc = tf.contrib.layers.dropout(fc, self.keep_prob)
   46                 fc = tf.nn.relu(fc)
   47                 # 分类器
   48                 self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
   49                 self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # 预测类别
   51             with tf.name_scope("optimize"):
   52                 # 损失函数,交叉熵
   53                 cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
   54                 self.loss = tf.reduce_mean(cross_entropy)
   55                 # 优化器
   56                 self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)
   58             with tf.name_scope("accuracy"):
   59                 # 准确率
   60                 correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
   61                 self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

3.3 运行代码run_cnn.py

  1 #!/usr/bin/python
  2 # -*- coding: utf-8 -*-
  3 from __future__ import print_function
  4 import os
  5 import sys
  6 import time
  7 from datetime import timedelta
  8 import numpy as np
  9 import tensorflow as tf
 10 from sklearn import metrics
 11 from cnn_model import TCNNConfig, TextCNN
 12 from  cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab
 14 base_dir = '../data/cnews'
 15 train_dir = os.path.join(base_dir, 'cnews.train.txt')
 16 test_dir = os.path.join(base_dir, 'cnews.test.txt')
 17 val_dir = os.path.join(base_dir, 'cnews.val.txt')
 18 vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
 19 save_dir = 'checkpoints/textcnn'
 20 save_path = os.path.join(save_dir, 'best_validation')  # 最佳验证结果保存路径
 22 def get_time_dif(start_time):
 23     """获取已使用时间"""
 24     end_time = time.time()
 25     time_dif = end_time - start_time
 26     return timedelta(seconds=int(round(time_dif)))
 28 def feed_data(x_batch, y_batch, keep_prob):
 29     feed_dict = {
 30         model.input_x: x_batch,
 31         model.input_y: y_batch,
 32         model.keep_prob: keep_prob
 33     }
 34     return feed_dict
 36 def evaluate(sess, x_, y_):
 37     """评估在某一数据上的准确率和损失"""
 38     data_len = len(x_)
 39     batch_eval = batch_iter(x_, y_, 128)
 40     total_loss = 0.0
 41     total_acc = 0.0
 42     for x_batch, y_batch in batch_eval:
 43         batch_len = len(x_batch)
 44         feed_dict = feed_data(x_batch, y_batch, 1.0)
 45         loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
 46         total_loss += loss * batch_len
 47         total_acc += acc * batch_len
 48     return total_loss / data_len, total_acc / data_len
 50 def train():
 51     print("Configuring TensorBoard and Saver...")
 52     # 配置 Tensorboard,重新训练时,请将tensorboard文件夹删除,不然图会覆盖
 53     tensorboard_dir = '../tensorboard/textcnn'
 54     if not os.path.exists(tensorboard_dir):
 55         os.makedirs(tensorboard_dir)
 56     tf.summary.scalar("loss", model.loss)
 57     tf.summary.scalar("accuracy", model.acc)
 58     merged_summary = tf.summary.merge_all()
 59     writer = tf.summary.FileWriter(tensorboard_dir)
 61     # 配置 Saver
 62     saver = tf.train.Saver()
 63     if not os.path.exists(save_dir):
 64         os.makedirs(save_dir)
 66     print("Loading training and validation data...")
 67     # 载入训练集与验证集
 68     start_time = time.time()
 69     x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
 70     x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
 71     time_dif = get_time_dif(start_time)
 72     print("Time usage:", time_dif)
 74     # 创建session
 75     session = tf.Session()
 76     session.run(tf.global_variables_initializer())
 77     writer.add_graph(session.graph)
 79     print('Training and evaluating...')
 80     start_time = time.time()
 81     total_batch = 0  # 总批次
 82     best_acc_val = 0.0  # 最佳验证集准确率
 83     last_improved = 0  # 记录上一次提升批次
 84     require_improvement = 1000  # 如果超过1000轮未提升,提前结束训练
 86     flag = False
 87     for epoch in range(config.num_epochs):
 88         print('Epoch:', epoch + 1)
 89         batch_train = batch_iter(x_train, y_train, config.batch_size)
 90         for x_batch, y_batch in batch_train:
 91             feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)
 92             #print("x_batch is {}".format(x_batch.shape))
 93             if total_batch % config.save_per_batch == 0:
 94                 # 每多少轮次将训练结果写入tensorboard scalar
 95                 s = session.run(merged_summary, feed_dict=feed_dict)
 96                 writer.add_summary(s, total_batch)
 97             if total_batch % config.print_per_batch == 0:
 98                 # 每多少轮次输出在训练集和验证集上的性能
 99                 feed_dict[model.keep_prob] = 1.0
100                 loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
101                 loss_val, acc_val = evaluate(session, x_val, y_val)  # todo
102                 if acc_val > best_acc_val:
103                     # 保存最好结果
104                     best_acc_val = acc_val
105                     last_improved = total_batch
106                     saver.save(sess=session, save_path=save_path)
107                     improved_str = '*'
108                 else:
109                     improved_str = ''
110                 time_dif = get_time_dif(start_time)
111                 msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
112                       + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
113                 print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))
115             session.run(model.optim, feed_dict=feed_dict)  # 运行优化
116             total_batch += 1
118             if total_batch - last_improved > require_improvement:
119                 # 验证集正确率长期不提升,提前结束训练
120                 print("No optimization for a long time, auto-stopping...")
121                 flag = True
122                 break  # 跳出循环
123         if flag:  # 同上
124             break
def test():
    """Evaluate the best saved checkpoint of the global ``model`` on the test set.

    Prints the overall test loss/accuracy, a per-class precision/recall/F1
    report and the confusion matrix.
    """
    print("Loading test data...")
    start_time = time.time()
    x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)

    # The with-block closes the session once predictions are collected.
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=session, save_path=save_path)  # load the saved model

        print('Testing...')
        loss_test, acc_test = evaluate(session, x_test, y_test)
        msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
        print(msg.format(loss_test, acc_test))

        batch_size = 128
        data_len = len(x_test)
        y_test_cls = np.argmax(y_test, 1)
        y_pred_cls = np.zeros(shape=data_len, dtype=np.int32)  # predicted class ids
        for start_id in range(0, data_len, batch_size):
            end_id = min(start_id + batch_size, data_len)
            feed_dict = {
                model.input_x: x_test[start_id:end_id],
                model.keep_prob: 1.0
            }
            y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)

    # Pass the full label set so the report and matrix always cover every
    # category and stay aligned with target_names, even when some class is
    # absent from both the true and the predicted labels.
    labels = list(range(len(categories)))
    print("Precision, Recall and F1-Score...")
    print(metrics.classification_report(y_test_cls, y_pred_cls, labels=labels, target_names=categories))

    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls, labels=labels)
    print(cm)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)
if __name__ == '__main__':
    # Usage: python run_cnn.py [train|test]   (defaults to 'train')
    option = sys.argv[1] if len(sys.argv) > 1 else 'train'
    if option not in ('train', 'test'):
        sys.exit("usage: python run_cnn.py [train|test]")

    config = TCNNConfig()
    if not os.path.exists(vocab_dir):  # build the vocabulary only if the file is missing
        build_vocab(train_dir, vocab_dir, config.vocab_size)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)
    config.vocab_size = len(words)  # embedding table must match the actual vocabulary
    model = TextCNN(config)

    if option == 'train':
        train()
    else:
        test()

3.4 预测predict.py

    1     # coding: utf-8
    2     from __future__ import print_function
    3     import os
    4     import tensorflow as tf
    5     import tensorflow.contrib.keras as kr
    6     from cnn_model import TCNNConfig, TextCNN
    7     from cnews_loader import read_category, read_vocab
    8     try:
    9         bool(type(unicode))
   10     except NameError:
   11         unicode = str
   13     base_dir = '../data/cnews'
   14     vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
   15     save_dir = '../checkpoints/textcnn'
   16     save_path = os.path.join(save_dir, 'best_validation')  # 最佳验证结果保存路径
   18     class CnnModel:
   19         def __init__(self):
   20             self.config = TCNNConfig()
   21             self.categories, self.cat_to_id = read_category()
   22             self.words, self.word_to_id = read_vocab(vocab_dir)
   23             self.config.vocab_size = len(self.words)
   24             self.model = TextCNN(self.config)
   25             self.session = tf.Session()
   26             self.session.run(tf.global_variables_initializer())
   27             saver = tf.train.Saver()
   28             saver.restore(sess=self.session, save_path=save_path)  # 读取保存的模型
   30         def predict(self, message):
   31             # 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行
   32             content = unicode(message)
   33             data = [self.word_to_id[x] for x in content if x in self.word_to_id]
   35             feed_dict = {
   36                 self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),
   37                 self.model.keep_prob: 1.0
   38             }
   40             y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)
   41             return self.categories[y_pred_cls[0]]
   43     if __name__ == '__main__':
   44         cnn_model = CnnModel()
   45         test_demo = ['三星ST550以全新的拍摄方式超越了以往任何一款数码相机',
   46                      '热火vs骑士前瞻:皇帝回乡二番战 东部次席唾手可得新浪体育讯北京时间3月30日7:00']
   47         for i in test_demo:
   48             print(cnn_model.predict(i))






