Text Classification with Keras

Text Classification with Keras

Data

采用清华文本分类数据集(THUCNews,http://thuctc.thunlp.org/message),按 1% 的采样率抽样,得到小样本数据集。语料文件的每一行是一个字典,包含 `text`(以空格分隔的分词结果)和 `label`(类别名)两个字段。

Code

# encoding=utf-8

import argparse
import ast
import logging

from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from tensorflow import keras


logger = logging.getLogger(__name__)


def parse_argv(args=None):
    """Parse command-line options for the training script.

    Args:
        args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic use).
            ``None`` keeps the original behaviour of reading the real
            command line.

    Returns:
        An ``argparse.Namespace`` with the fields ``corpus``, ``tr``,
        ``max_len``, ``embed_dim``, ``epochs``, ``batch_size`` and ``log``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("corpus", type=str, help="file path of corpus")
    parser.add_argument("--tr", type=float, default=0.2, help="ratio of test data")
    parser.add_argument(
        "--max_len", type=int, default=100,
        help="max sequence length (in tokens) after padding/truncation",
    )
    # The help text previously said "max_len" (copy-paste error).
    parser.add_argument(
        "--embed_dim", type=int, default=50,
        help="dimension of word embeddings (also used as LSTM units)",
    )
    parser.add_argument("--epochs", type=int, default=3, help="epoch of training")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size of training")
    parser.add_argument("--log", type=str, default=None, help="log file path")
    return parser.parse_args(args)


def load_corpus(path):
    """Read a line-delimited corpus of ``{'text': ..., 'label': ...}`` records.

    Each non-blank line must be a Python literal (typically a dict literal)
    containing a ``'text'`` and a ``'label'`` key. Blank lines are skipped.

    Args:
        path: Path to the UTF-8 encoded corpus file.

    Returns:
        A ``(texts, labels)`` tuple of two parallel lists.

    Raises:
        ValueError / SyntaxError: if a line is not a valid Python literal.
        KeyError: if a record lacks the ``'text'`` or ``'label'`` key.
    """
    x, y = list(), list()
    # ``with`` guarantees the file handle is closed (it used to leak).
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # e.g. a trailing empty line; parsing "" would raise.
                continue
            # literal_eval only accepts literals, so a crafted corpus line
            # cannot execute arbitrary code the way eval() could.
            record = ast.literal_eval(line)
            x.append(record['text'])
            y.append(record['label'])
    return x, y


def word_to_index(sentences, vocab=None):
    """Map space-tokenised sentences to integer id sequences.

    Ids are assigned to unseen words in order of first appearance.

    Args:
        sentences: Iterable of strings whose tokens are separated by a single
            space ``' '``.
        vocab: Optional existing ``{word: id}`` mapping to start from. It is
            copied, never mutated. Words missing from it receive new ids
            starting right after its largest id (0 when ``vocab`` is None or
            empty).

    Returns:
        ``(sequences, vocab)``: a list of id lists (one per sentence) and the
        resulting ``{word: id}`` mapping.
    """
    # ``vocab`` used to be accepted but silently ignored.
    d = dict(vocab) if vocab is not None else dict()
    # max(...)+1 rather than len(d) so a non-contiguous vocab can't collide.
    next_id = max(d.values(), default=-1) + 1
    m = list()
    for sentence in sentences:
        ids = list()
        # split(' ') (not split()) keeps the original tokenisation exactly.
        for word in sentence.split(' '):
            if word not in d:
                d[word] = next_id
                next_id += 1
            ids.append(d[word])
        m.append(ids)
    return m, d


def label_to_index(labels):
    """Encode class labels as integer ids in order of first appearance.

    Args:
        labels: Iterable of hashable labels (consumed in a single pass).

    Returns:
        ``(ids, mapping)``: the id for each input label, and the
        ``{label: id}`` mapping whose insertion order matches the ids.
    """
    mapping = {}
    ids = []
    for lab in labels:
        # setdefault evaluates len(mapping) before inserting, so a new
        # label gets the next free id; a known label keeps its id.
        ids.append(mapping.setdefault(lab, len(mapping)))
    return ids, mapping


def _build_model(vocab_size, embed_dim, num_classes):
    """Build and compile the Embedding -> Conv1D -> 2x BiLSTM -> softmax model."""
    model = keras.models.Sequential()
    model.add(keras.layers.Embedding(vocab_size, embed_dim))
    model.add(keras.layers.Conv1D(32, 7, activation='relu'))
    model.add(keras.layers.MaxPooling1D(3))
    model.add(
        keras.layers.Bidirectional(
            keras.layers.LSTM(
                return_sequences=True,
                units=embed_dim
            )
        )
    )
    model.add(
        keras.layers.Bidirectional(
            keras.layers.LSTM(units=embed_dim)
        )
    )
    model.add(keras.layers.Dense(num_classes, activation='softmax'))
    model.compile(
        optimizer='rmsprop',
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model


def run(argv):
    """Train the classifier on ``argv.corpus`` and print a test-set report.

    Reads ``argv.corpus``, ``argv.tr``, ``argv.max_len``, ``argv.embed_dim``,
    ``argv.epochs`` and ``argv.batch_size``. Prints the model summary, the
    training progress and an sklearn classification report for the held-out
    split. Returns None.
    """
    x, y = load_corpus(argv.corpus)
    x, vocab = word_to_index(x)
    y, label = label_to_index(y)
    # pad_sequences pads with 0, but word_to_index also hands id 0 to the
    # first word it sees. Shift every id by one so 0 means "padding" only;
    # the embedding table then needs len(vocab) + 1 rows.
    x = [[i + 1 for i in seq] for seq in x]
    x_train, x_test, y_train, y_test = train_test_split(
        x, y,
        test_size=argv.tr,
        random_state=33
    )
    x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=argv.max_len)
    x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=argv.max_len)
    y_train = keras.utils.to_categorical(y_train, num_classes=len(label))

    model = _build_model(len(vocab) + 1, argv.embed_dim, len(label))
    model.summary()
    model.fit(
        x=x_train,
        y=y_train,
        epochs=argv.epochs,
        verbose=1,
        batch_size=argv.batch_size
    )
    # Sequential.predict_classes was removed in TensorFlow 2.6; take the
    # argmax over the softmax outputs instead.
    y_pred = model.predict(x_test).argmax(axis=-1)
    # label_to_index assigns ids in insertion order, so list(label)[i] names
    # class i. Passing ``labels`` keeps names aligned even if some class is
    # absent from both y_test and y_pred.
    print(classification_report(
        y_test, y_pred,
        labels=list(range(len(label))),
        target_names=list(label),
    ))


def main():
    """Entry point: parse CLI options, configure logging, then train."""
    argv = parse_argv()
    logging.basicConfig(
        filename=argv.log,  # None -> basicConfig logs to stderr instead
        # %(message)s is the interpolated text; %(msg)s (used previously) is
        # the raw format string, so "x=%s" args would never be filled in.
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO)
    run(argv)


if __name__ == '__main__':
    main()

Output

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, None, 50)          7725100   
_________________________________________________________________
conv1d (Conv1D)              (None, None, 32)          11232     
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, None, 32)          0         
_________________________________________________________________
bidirectional (Bidirectional (None, None, 100)         33200     
_________________________________________________________________
bidirectional_1 (Bidirection (None, 100)               60400     
_________________________________________________________________
dense (Dense)                (None, 14)                1414      
=================================================================
Total params: 7,831,346
Trainable params: 7,831,346
Non-trainable params: 0
_________________________________________________________________

Epoch 1/3
209/209 [==============================] - 15s 64ms/step - loss: 2.1679 - accuracy: 0.2324
Epoch 2/3
209/209 [==============================] - 13s 64ms/step - loss: 1.2895 - accuracy: 0.5615
Epoch 3/3
209/209 [==============================] - 13s 64ms/step - loss: 0.6497 - accuracy: 0.8001

              precision    recall  f1-score   support

          时尚       0.24      0.33      0.28        18
          家居       0.35      0.44      0.39        70
          教育       0.79      0.52      0.63        88
          股票       0.81      0.92      0.86       313
          娱乐       0.78      0.57      0.66       187
          彩票       0.00      0.00      0.00        20
          社会       0.44      0.65      0.52        94
          房产       0.44      0.51      0.47        45
          星座       0.00      0.00      0.00         6
          科技       0.80      0.75      0.77       320
          财经       0.40      0.48      0.44        83
          时政       0.52      0.61      0.56       116
          游戏       0.73      0.58      0.65        52
          体育       0.89      0.82      0.86       259

    accuracy                           0.69      1671
   macro avg       0.51      0.51      0.51      1671
weighted avg       0.70      0.69      0.69      1671

0.6918013165769

Analysis

现象1:对于样本数量较少的类别,如星座(6)、时尚(18)、彩票(20),分类性能较差;对于样本数量较多的类别,如科技(320)、股票(313)、体育(259),分类性能较好。括号内为各类别在测试集中的样本数,即报告中的 support。

结论1:各类别的样本数量(即类别不均衡)对模型性能影响较大。

现象2:模型在训练集上的分类准确率达到 80.01%,在测试集上仅为 69.18%。训练集准确率取自第 3 个 epoch 的训练日志。

结论2:训练集与测试集的准确率相差约 11 个百分点,说明模型在训练集上存在一定程度的过拟合,泛化能力不足。

posted @ 2021-08-03 14:18  健康平安快乐  阅读(71)  评论(0编辑  收藏  举报