Scikit-learn 多标签分类 multilabel classification(大量训练数据,MultiOutputClassifier,partial_fit)

核心代码:

# from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.naive_bayes import MultinomialNB
from utils.data_util import load_pickle
import os
from pathConfig import data_dir
from utils.vocab_util import vocab_to_index_dict
import numpy as np

# Directories holding the training and evaluation splits under the data root.
train_dir, test_dir = (
    os.path.join(data_dir, split_name) for split_name in ("train", "test")
)

# train
# classifier = SVC(kernel='linear', probability=True)
# classifier = LogisticRegression()
# NOTE: the base estimator must implement partial_fit for the incremental
# loop below; MultinomialNB does, SVC and LogisticRegression do not.
classifier = MultinomialNB()
print("Training classifier ", str(classifier))
# One independent copy of the base estimator is trained per output label.
clf = MultiOutputClassifier(classifier, n_jobs=24)

# Every output is a binary indicator, so each per-label estimator is told up
# front that its classes are 0 and 1; a single chunk may contain only one of
# them. The list is loop-invariant, so build it once.
label_classes = [[0, 1]] * len(label_vocab)

# Feed every training file to the model in turn so the full training set
# never has to be held in memory at once. The previous unconditional `break`
# stopped after the first file, so the remaining files were never used.
# sorted() makes the processing order reproducible across runs and machines.
for fname in sorted(os.listdir(train_dir)):
    fpath = os.path.join(train_dir, fname)
    print("loading file ", fpath)
    train_X, train_y = load_train_file(fpath)
    print("partial_fiting...")
    clf.partial_fit(train_X, train_y, classes=label_classes)

# test
test_X, test_y = load_test_data()

# MultiOutputClassifier.predict_proba returns a list with one array per
# output label; each array has shape (n_samples, n_classes) with columns in
# the order of that label's classes_ (here [0, 1]).
y_pred = clf.predict_proba(test_X)

# Keep the probability of the positive class (column 1) for every label.
# The previous 0.5 * P(0) + 0.5 * P(1) always evaluated to 0.5 because the
# two probabilities sum to 1, which threw the model output away.
# Iterating y_pred directly also avoids len(test_X), which fails on sparse
# matrices, and does not depend on an external vocabulary size.
# Resulting shape: (n_samples, n_labels). The variable name is kept as-is
# for any code further down that refers to it.
y_pred_prcessed = np.stack([label_proba[:, 1] for label_proba in y_pred], axis=1)
posted @ 2020-08-14 18:24  max_xbw  阅读(820)  评论(0)  编辑  收藏  举报