基于Keras的fastText短文本分类
### train_model.py ###
#!/usr/bin/env python
# coding=utf-8
"""Train the fastText-style industry classifier (data-preparation part).

Loads the labelled CSV, maps characters to integer ids, optionally appends
n-gram ids, and pads every sequence to ``maxlen``.
"""
import codecs
import simplejson as json
import numpy as np
import pandas as pd
from keras.models import Sequential, load_model
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.preprocessing import sequence
from keras.utils import to_categorical
from keras.layers import *
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
try:
    from sklearn.externals import joblib
except ImportError:  # scikit-learn >= 0.23 no longer vendors joblib
    import joblib
import logging
import re
import pickle as pkl

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(filename)s: %(message)s',
                    datefmt='%Y-%m-%d %H:%M',
                    filename='log/train_model.log',
                    filemode='a+')

ngram_range = 1      # 1 = unigrams only; N > 1 also appends 2..N-gram ids
max_features = 6500  # size of the character-id space; n-gram ids start above it
maxlen = 120         # every sequence is padded / truncated to this length

DIRTY_LABEL = re.compile(r'\W+')
# Presumably a sample of the stop-word set, e.g. set([u'业务', u'销售', u'零售', ...]).
# NOTE(review): STOP_WORDS is loaded but not applied during training, while
# api_tgind.TgIndustry.predict removes stop words (via jieba) before lookup.
with open('./data/stopwords.pkl', 'rb') as _fr:
    STOP_WORDS = pkl.load(_fr)


def load_data(fname='data/12315_industry_business_train.csv', nrows=None):
    """Load the training CSV and split it into train / test sets.

    The ``name`` and ``business`` text of each row is concatenated (values
    only, in column order) and mapped to character ids; characters missing
    from ``data/char2idx.json`` are dropped and empty (NaN) cells are
    skipped.  Labels are integer-encoded with a ``LabelEncoder`` (saved to
    ``model/tgind_labelencoder.h5`` for the prediction API) and one-hot
    encoded.

    :param fname: CSV path; needs ``label`` plus ``name``/``business`` columns.
    :param nrows: optional row limit forwarded to ``pandas.read_csv``.
    :return: ``((x_train, y_train), (x_test, y_test))``; x are lists of id
        lists, y are one-hot arrays; 10% of the rows form the test split.
    """
    with open('data/char2idx.json') as fr:
        char2idx = json.load(fr)
    used_keys = set(['name', 'business'])
    df = pd.read_csv(fname, encoding='utf-8', nrows=nrows)

    data, labels = [], []
    for _, item in df.iterrows():
        # Only the field values are concatenated (no column names), matching
        # how TgIndustry.predict builds its input at serving time.
        parts = [value for key, value in item.items()
                 if key in used_keys and not pd.isnull(value)]
        line = u''.join(parts)
        data.append([char2idx[char] for char in line if char in char2idx])
        labels.append(item['label'])

    classes = np.unique(labels)
    logging.info('%d nb_class: %s' % (len(classes), str(classes)))
    le = LabelEncoder()
    onehot_label = to_categorical(le.fit_transform(labels))
    joblib.dump(le, 'model/tgind_labelencoder.h5')
    x_train, x_test, y_train, y_test = train_test_split(data, onehot_label, test_size=0.1)
    return (x_train, y_train), (x_test, y_test)


def create_ngram_set(input_list, ngram_value=2):
    """Return the set of contiguous ``ngram_value``-tuples in ``input_list``."""
    return set(zip(*[input_list[i:] for i in range(ngram_value)]))


def add_ngram(sequences, token_indice, ngram_range=2):
    """Augment each id sequence by appending the ids of its known n-grams.

    For every start position, the 2-gram .. ``ngram_range``-gram starting
    there is looked up in ``token_indice``. If it is found, its id is
    appended to a copy of the sequence. Trailing lower-order n-grams (e.g.
    the last bigram when ``ngram_range=3``) are included too.

    :param sequences: iterable of id lists (not modified).
    :param token_indice: dict mapping n-gram tuples to new integer ids.
    :param ngram_range: largest n-gram order to append.
    :return: list of augmented id lists.
    """
    new_sequences = []
    for input_list in sequences:
        new_list = list(input_list)
        n = len(input_list)
        for i in range(n - 1):
            for ngram_value in range(2, ngram_range + 1):
                if i + ngram_value > n:
                    break  # longer n-grams also run past the end
                ngram_id = token_indice.get(tuple(input_list[i:i + ngram_value]))
                if ngram_id is not None:
                    new_list.append(ngram_id)
        new_sequences.append(new_list)
    return new_sequences


(x_train, y_train), (x_test, y_test) = load_data()
nb_class = y_train.shape[1]
avg_len = np.mean([len(x) for x in x_train])
logging.info('x_train size: %d' % (len(x_train)))
logging.info('x_test size: %d' % (len(x_test)))
logging.info('x_train sent average len: %.2f' % avg_len)
print('x_train sent avg length: %.2f' % avg_len)

if ngram_range > 1:
    print('add {}-gram features'.format(ngram_range))
    ngram_set = set()
    for input_list in x_train:
        for i in range(2, ngram_range + 1):
            ngram_set.update(create_ngram_set(input_list, ngram_value=i))
    # n-gram ids are allocated after the character ids so they never collide
    start_index = max_features + 1
    token_indice = {v: k + start_index for k, v in enumerate(ngram_set)}
    indice_token = {token_indice[k]: k for k in token_indice}
    if indice_token:  # np.max fails on an empty n-gram set
        max_features = np.max(list(indice_token.keys())) + 1
    x_train = add_ngram(x_train, token_indice, ngram_range)
    x_test = add_ngram(x_test, token_indice, ngram_range)

print('pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen, padding='post', truncating='post')
x_test = sequence.pad_sequences(x_test, maxlen=maxlen, padding='post', truncating='post')
logging.info('x_train.shape: %s' % (str(x_train.shape)))

print('build model...')
MODEL_PATH = './model/tgind_dalei.h5'


def cal_accuracy(x_test, y_test, clf=None):
    """Return the top-1 accuracy of a classifier on a one-hot labelled set.

    :param x_test: padded id matrix.
    :param y_test: one-hot label matrix with the same number of rows.
    :param clf: Keras model to evaluate; defaults to the module-level
        ``model`` (the original behaviour).
    :return: fraction of rows whose argmax prediction equals the argmax
        label; 0.0 for an empty set.
    """
    if clf is None:
        clf = model
    y_true = np.argmax(y_test, axis=1)
    if len(y_true) == 0:
        return 0.0
    # ``predict_classes`` was removed from newer Keras/TF; the argmax of the
    # softmax output is the equivalent for this single-label model.
    y_pred = np.argmax(clf.predict(x_test), axis=1)
    return float(np.sum(y_pred == y_true)) / len(y_true)


# DEBUG=True builds a fresh model (embedding -> global average pooling ->
# dropout -> softmax); DEBUG=False resumes training from the checkpoint that
# ModelCheckpoint below wrote on a previous run.
DEBUG = False
if DEBUG:
    model = Sequential()
    model.add(Embedding(max_features, 200, input_length=maxlen))
    model.add(GlobalAveragePooling1D())
    model.add(Dropout(0.3))
    model.add(Dense(nb_class, activation='softmax'))
else:
    model = load_model(MODEL_PATH)

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
earlystop = EarlyStopping(monitor='val_loss', patience=8)
# only the checkpoint with the lowest validation loss is kept on disk
checkpoint = ModelCheckpoint(filepath=MODEL_PATH, monitor='val_loss',
                             save_best_only=True, save_weights_only=False)
model.fit(x_train, y_train, shuffle=True, batch_size=64, epochs=80,
          validation_split=0.1, callbacks=[checkpoint, earlystop])

# score the model as it is after the last epoch ...
loss, acc = model.evaluate(x_test, y_test)
print('\n\nlast model: loss %s' % loss)
print('acc %s' % acc)

# ... then reload and score the best checkpoint
model = load_model(MODEL_PATH)
loss, acc = model.evaluate(x_test, y_test)
print('\n\n cur best model: loss %s' % loss)
print('accuracy %s' % acc)
logging.info('loss: %.4f ;accuracy: %.4f' % (loss, acc))
logging.info('\nmodel acc: %.4f' % acc)
logging.info('\nmodel config:\n %s' % model.get_config())
### test_model.py ###
#!/usr/bin/env python
# coding=utf-8
"""Evaluate TgIndustry predictions and plot accuracy/recall vs. threshold."""
import matplotlib.pyplot as plt
from api_tgind import TgIndustry
import pandas as pd
import codecs
import json
from collections import OrderedDict


def _dump_results(res, fname='log/total_acc_output.json'):
    """Write ``res`` as JSON to ``fname``, closing the file afterwards."""
    with codecs.open(fname, 'wb', encoding='utf-8') as fw:
        json.dump(res, fw)


########### Accuracy as a function of the confidence threshold ###########

def cal_model_acc(model, fname='./data/industry_dalei_test_sample2k.txt', nrows=None):
    """Run ``model`` over a fastText-format text file and collect predictions.

    Each line is split on whitespace. The last token is the label, prefixed
    with ``__label__``; the other tokens are joined into the ``business``
    text. Lines with 3 or fewer tokens are skipped.

    :param model: object with ``predict(dict) -> {label: prob}`` (e.g. TgIndustry).
    :param fname: input file path.
    :param nrows: if set, read at most this many lines.
    :return: dict with parallel lists ``y_pred`` (prediction dicts) and
        ``y_true`` (labels); also written to ``log/total_acc_output.json``.
    """
    res = {'y_pred': [], 'y_true': []}
    with codecs.open(fname, encoding='utf-8') as fr:
        for idx, line in enumerate(fr):
            if nrows and idx >= nrows:
                break
            tokens = line.strip().split()
            if len(tokens) > 3:
                tokens, label = tokens[:-1], tokens[-1].replace('__label__', '')
                res['y_pred'].append(model.predict({'business': ''.join(tokens)}))
                res['y_true'].append(label)
    _dump_results(res)
    return res


def cal_model_acc2(model, fname='data/test_12315_industry_business_sample100.csv', nrows=None):
    """Run ``model`` directly over the rows of a labelled CSV.

    Rows for which ``model.predict`` raises are reported on stdout and
    skipped, so ``y_pred`` and ``y_true`` stay aligned.

    :param model: object with ``predict(dict) -> {label: prob}``.
    :param fname: CSV path; each row dict is passed to ``predict`` and its
        ``label`` column is the ground truth.
    :param nrows: if set, read at most this many rows.
    :return: dict with parallel lists ``y_pred`` and ``y_true``; also written
        to ``log/total_acc_output.json``.
    """
    res = {'y_pred': [], 'y_true': []}
    df = pd.read_csv(fname, encoding='utf-8', nrows=nrows)
    for idx, item in df.iterrows():
        try:
            y_pred = model.predict(item.to_dict())
        except Exception as e:  # best effort: report the bad row and move on
            print(e)
            print(idx)
            print(item.get('name'))
            continue
        res['y_pred'].append(y_pred)
        res['y_true'].append(item['label'])
    _dump_results(res)
    return res


def _threshold_accuracy(res, topk, threhold, is_hit):
    """Accuracy and recall over the predictions whose top probability >= threhold.

    :param res: dict with parallel lists ``y_pred`` ({label: prob}) and ``y_true``.
    :param topk: how many top-ranked labels count as a hit.
    :param threhold: minimum top-1 probability for a prediction to be accepted.
    :param is_hit: ``f(y_true, top_labels) -> bool``.
    :return: ``(accuracy, recall)``; accuracy is hits / accepted, recall is
        accepted / total. Each is 0.0 when its denominator is 0.
    """
    correct_cnt, total_cnt = 0, 0
    for y_pred, y_true in zip(res['y_pred'], res['y_true']):
        ranked = sorted(((c, float(s)) for c, s in y_pred.items()),
                        key=lambda x: x[1], reverse=True)  # highest probability first
        if not ranked or ranked[0][1] < threhold:
            continue
        total_cnt += 1
        if is_hit(y_true, [c for c, _ in ranked[:topk]]):
            correct_cnt += 1
    acc = float(correct_cnt) / total_cnt if total_cnt else 0.0
    recall = float(total_cnt) / len(res['y_true']) if res['y_true'] else 0.0
    return acc, recall


def get_model_acc_menlei(res, topk=5, threhold=0.8):
    """First-level (menlei) accuracy/recall: only the label's first character must match."""
    return _threshold_accuracy(
        res, topk, threhold,
        lambda y_true, top: y_true[0] in [c[0] for c in top])


def get_model_acc_dalei(res, topk=5, threhold=0.8):
    """Second-level (dalei) accuracy/recall: the full label must be among the top-k."""
    return _threshold_accuracy(
        res, topk, threhold,
        lambda y_true, top: y_true in top)


def plot_accuracy(title, df, number):
    """Save one accuracy/recall-vs-threshold plot per ``topk`` value in ``df``.

    :param title: prefix for the figure title and the output file name.
    :param df: DataFrame with ``topk``, ``threhold``, ``accuracy``, ``recall`` columns.
    :param number: sample count shown in the title.
    """
    for topk in sorted(set(df['topk'])):
        tmpdf = df[df.topk == topk]
        fig = plt.figure()
        ax1 = fig.add_subplot(111)
        plt.subplots_adjust(top=0.85)
        ax1.plot(tmpdf['threhold'], tmpdf['accuracy'], 'ro-', label='accuracy')
        ax1.plot(tmpdf['threhold'], tmpdf['recall'], 'g^-', label='recall')
        ax1.set_ylim(0.3, 1.0)
        ax1.legend(loc=3)
        ax1.set_xlabel('threhold')
        plt.grid(True)
        plt.title('%s Industry Classify Result\n topk=%d, number=%d\n' % (title, topk, number))
        plt.savefig('log/test_%s_acc_topk%d.png' % (title, topk))
        plt.close(fig)  # free the figure; pyplot keeps every open one alive
        print('%s done!' % topk)
def gen_plot_data(model_acc, ctype='2nd'):
    """Sweep top-k (1..4) and thresholds (0.0..0.9), then save and plot the results.

    :param model_acc: dict with parallel lists ``y_pred`` and ``y_true``.
    :param ctype: ``'1st'`` for first-level (menlei) accuracy; anything else
        for second-level (dalei) accuracy.
    :return: DataFrame with ``accuracy``, ``recall``, ``threhold``, ``topk``
        columns. It is also written to ``log/test_model_threshold_<ctype>.log``
        (JSON) and ``log/test_model_result_<ctype>.csv``.
    """
    res = {'accuracy': [], 'threhold': [], 'topk': [], 'recall': []}
    for topk in range(1, 5):
        for step in range(0, 10):
            # Divide instead of multiplying: 3 / 10.0 == 0.3 exactly, whereas
            # 0.1 * 3 == 0.30000000000000004 would exclude probabilities of
            # exactly 0.3 (TgIndustry.predict rounds them to 3 decimals).
            threhold = step / 10.0
            if ctype == '1st':
                acc, recall = get_model_acc_menlei(model_acc, topk, threhold)
            else:
                acc, recall = get_model_acc_dalei(model_acc, topk, threhold)
            res['accuracy'].append(acc)
            res['recall'].append(recall)
            res['threhold'].append(threhold)
            res['topk'].append(topk)
            print('%s %s %s' % (ctype, topk, acc))
    with open('log/test_model_threshold_%s.log' % ctype, 'w') as fw:
        json.dump(res, fw)
    df = pd.DataFrame(res)
    df.to_csv('log/test_model_result_%s.csv' % ctype, index=False)
    plot_accuracy(ctype, df, len(model_acc['y_true']))
    return df


if __name__ == '__main__':
    # Set to True to re-run the model over the CSV. Otherwise previously saved
    # predictions are reused and the (slow) model load is skipped.
    # NOTE(review): cal_model_acc2 writes log/total_acc_output.json, while the
    # cached path below is ..._12315.json — presumably a manually renamed copy.
    RECOMPUTE = False
    if RECOMPUTE:
        model = TgIndustry()
        model_acc = cal_model_acc2(model, fname='data/test_12315_industry_business_sample100.csv')
    else:
        with codecs.open('log/total_acc_output_12315.json', encoding='utf-8') as fr:
            model_acc = json.load(fr)
    gen_plot_data(model_acc, '1st')
    gen_plot_data(model_acc, '2nd')
### api_tgind.py ###
#!/usr/bin/env python
# coding=utf-8
"""
Industry-classification prediction API.

__author__: jkmiao
__date__: 2017-07-05
"""
import numpy as np
import codecs
import simplejson as json
from keras.models import load_model
from keras.preprocessing import sequence
try:
    from sklearn.externals import joblib
except ImportError:  # scikit-learn >= 0.23 no longer vendors joblib
    import joblib
from collections import OrderedDict
import pickle as pkl
import re, os
import jieba
import time


def _is_missing(value):
    """True for None and float NaN (how pandas represents an empty CSV cell)."""
    return value is None or (isinstance(value, float) and value != value)


class TgIndustry(object):
    """Industry classifier wrapping a trained Keras model and its lookup tables."""

    # sequence length used when the model does not declare a fixed input length
    DEFAULT_MAXLEN = 100

    def __init__(self, model_path='model/tgind_dalei_acc76.h5'):
        """Load the model and its resources from paths relative to this module.

        :param model_path: Keras model file, relative to this module's directory.
        """
        base_path = os.path.dirname(__file__)
        # pre-trained Keras model
        self.model = load_model(os.path.join(base_path, model_path))
        # LabelEncoder mapping class indices back to label codes
        self.le = joblib.load(os.path.join(base_path, './model/tgind_labelencoder.h5'))
        # character -> integer id table
        with open(os.path.join(base_path, 'data/char2idx.json')) as fr:
            self.char2idx = json.load(fr)
        # stop words removed from the input text
        with open(os.path.join(base_path, './data/stopwords.pkl'), 'rb') as fr:
            self.stop_words = pkl.load(fr)
        # label code -> display name, first-level categories
        with open(os.path.join(base_path, 'data/menlei_label2name.json')) as fr:
            self.menlei_label2name = json.load(fr)
        # label code -> display name, second-level categories
        with open(os.path.join(base_path, 'data/dalei_label2name.json')) as fr:
            self.dalei_label2name = json.load(fr)
        # pad to the length the model was built for rather than a hard-coded value
        self.maxlen = self._input_length()

    def _input_length(self):
        """Return the model's fixed input length, or ``DEFAULT_MAXLEN`` if it has none."""
        shape = getattr(self.model, 'input_shape', None)
        try:
            length = shape[1]
        except (TypeError, IndexError):
            length = None
        return length if isinstance(length, int) and length > 0 else self.DEFAULT_MAXLEN

    def predict(self, company_info, topk=2, firstIndustry=False, final_name=False):
        """Predict the ``topk`` most likely industry labels for one company.

        :param company_info: dict; only its ``name`` and ``business`` values
            are used, concatenated in iteration order. Missing (None/NaN)
            values are skipped.
        :param topk: number of highest-probability labels to return.
        :param firstIndustry: also add each label's first character as a key,
            with that label's probability. A later, lower-ranked label with the
            same first character overwrites it.
            NOTE(review): with ``final_name=True`` the first character is taken
            from the display name rather than the label code — confirm intent.
        :param final_name: map label codes to names via ``dalei_label2name``.
        :return: OrderedDict ``{label: probability}``, highest probability
            first, probabilities rounded to 3 decimals.
        """
        parts = []
        for key, value in company_info.items():
            if key not in ('name', 'business') or _is_missing(value):
                continue
            # decode per value so mixed byte/unicode inputs never get concatenated
            if isinstance(value, bytes):
                value = value.decode('utf-8')
            parts.append(value)
        line = u''.join(parts)
        # drop stop words (segmented with jieba), then map the remaining chars to ids
        line = u''.join(token for token in jieba.cut(line) if token not in self.stop_words)
        data = [self.char2idx[char] for char in line if char in self.char2idx]
        data = sequence.pad_sequences([data], maxlen=self.maxlen, padding='post', truncating='post')
        y_pred_proba = self.model.predict(data, verbose=0)
        top_indices = np.argsort(y_pred_proba[0])[-topk:][::-1]

        res = OrderedDict()
        for y_pred_idx in top_indices:
            # float() first: round() on a numpy float32 stays float32 on
            # Python 3, which json cannot serialise
            prob = round(float(y_pred_proba[0, y_pred_idx]), 3)
            # newer scikit-learn only accepts 1-D arrays here
            y_pred_label = self.le.inverse_transform([y_pred_idx])[0]
            if final_name:
                y_pred_label = self.dalei_label2name[y_pred_label]
            if firstIndustry:
                res[y_pred_label[0]] = prob
            res[y_pred_label] = prob
        return res


if __name__ == '__main__':
    # Quick accuracy check on the first ~200 lines of the fastText-format sample.
    DIRTY_LABEL = re.compile(r'\W+')
    test = TgIndustry()
    cnt, total_cnt, n_lines = 0, 0, 0
    start_time = time.time()
    with codecs.open('./output/industry_dalei_test_sample2k_error.txt', 'wb', encoding='utf-8') as fw2, \
            codecs.open('./data/industry_dalei_test_sample2k.txt', encoding='utf-8') as fr:
        for idx, line in enumerate(fr):
            n_lines += 1
            tokens = line.strip().split()
            if len(tokens) > 3:
                tokens, label = tokens[:-1], tokens[-1].replace('__label__', '')
                # skip labels that are not 2-3 word characters long
                if len(label) not in [2, 3] or DIRTY_LABEL.search(label):
                    print('error line:')
                    print(u'%s %s %s' % (idx, line, label))
                    continue
                text = ''.join(tokens)
                y_pred = test.predict({'business': text}, topk=1)
                if label in y_pred:
                    cnt += 1
                elif next(iter(y_pred.values())) < 0.3:
                    # wrong and low-confidence: log the text for inspection
                    print(u'error:  %s %s y_true: %s' % (text, y_pred, label))
                    fw2.write(text + u'\n')
                total_cnt += 1
                print(label)
                print(json.dumps(y_pred, ensure_ascii=False))
                print('%s %s %s' % (idx, '==' * 20, float(cnt) / total_cnt))
            if idx > 200:
                break
    if n_lines:
        print('avg cost time: %s' % (float(time.time() - start_time) / n_lines))
每天一小步，人生一大步！Good luck~