1、实现效果
2、相关代码
实现BP训练模型的线程类
class WorkThread(QtCore.QThread):
    """Worker thread that trains a BP (fully connected) Keras model.

    Intended use, as shown by the calling snippet: construct the thread, call
    ``init(...)`` with the training inputs, connect the signals, then
    ``start()``.

    Layout of ``info`` as it is read by ``BP``:
      * ``info[0]``  -- ``[learning_rate, epochs, batch_size, test_size]``
      * ``info[1:]`` -- one ``[units, activation, kernel_initializer]`` entry
        per Dense layer; the last entry describes the output layer.
    """

    finish_trigger = QtCore.pyqtSignal()  # asks the GUI to close the waiting GIF
    result_trigger = QtCore.pyqtSignal(pd.Series)  # carries the predictions
    evaluate_trigger = QtCore.pyqtSignal(list)  # carries the accuracy percentages

    def __init__(self, parent=None):
        # The original spelled this ``__int__``, so it was never used as the
        # constructor. ``parent`` is forwarded so the QThread constructor
        # signature that callers previously got by inheritance still works.
        super(WorkThread, self).__init__(parent)
        self.dataset = None
        self.feature = None
        self.label = None
        self.info = None
        self.model = None

    def init(self, dataset, feature, label, info):
        """Store the training inputs; must be called before ``start()``.

        :param dataset: table indexed by column name (``dataset[feature]``
            and ``dataset[label]`` are read) -- presumably a pandas
            DataFrame; confirm against the caller.
        :param feature: column selector for the input features; its length
            is used as the network input dimension.
        :param label: column selector for the target values.
        :param info: hyper-parameters and layer specs (see class docstring).
        """
        self.dataset = dataset
        self.feature = feature
        self.label = label
        self.info = info

    def run(self):
        """Thread entry point: everything called from here runs in the new thread."""
        self.BP()

    def _build_model(self, input_dim):
        """Return a compiled Sequential model described by ``self.info``."""
        settings = self.info[0]
        model = Sequential()
        for position, spec in enumerate(self.info[1:]):
            layer_options = {'activation': spec[1], 'kernel_initializer': spec[2]}
            if position == 0:
                # Only the first layer added needs the input size. Tying this
                # to "first layer" rather than "first hidden layer" also
                # covers a network that consists of the output layer alone.
                layer_options['input_dim'] = input_dim
            model.add(Dense(spec[0], **layer_options))

        optimizer = SGD(lr=settings[0], decay=1e-6, momentum=0.9, nesterov=True)
        model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        return model

    def BP(self):
        """Train the network, then emit the accuracies and the predictions.

        Signals are emitted in this order: ``finish_trigger``,
        ``evaluate_trigger`` (train / test / whole-dataset accuracy, in
        percent), ``result_trigger`` (predictions for the whole dataset as a
        Series named '预测(BP)').
        """
        settings = self.info[0]
        try:
            data_x = self.dataset[self.feature]  # feature columns
            data_y = self.dataset[self.label]  # label column(s)

            x_train, x_test, y_train, y_test = train_test_split(
                data_x, data_y, test_size=settings[3])

            self.model = self._build_model(len(self.feature))

            # Train for the configured number of epochs and batch size.
            self.model.fit(x_train, y_train, validation_data=(x_test, y_test),
                           epochs=settings[1], batch_size=settings[2])

            scores_train = self.model.evaluate(x_train, y_train, batch_size=10)
            scores_test = self.model.evaluate(x_test, y_test, batch_size=10)
            scores_all = self.model.evaluate(data_x, data_y, batch_size=10)
        finally:
            # Always release the waiting indicator, even when training fails;
            # otherwise the GUI would keep showing the GIF indefinitely.
            self.finish_trigger.emit()

        # Index 1 of each evaluate() result is the 'accuracy' metric
        # requested in compile(); index 0 is the loss.
        accuracies = [scores_train[1] * 100, scores_test[1] * 100, scores_all[1] * 100]
        self.evaluate_trigger.emit(accuracies)

        # First output column of the prediction matrix, one value per row.
        result = pd.Series(self.model.predict(data_x).T[0])
        result.name = '预测(BP)'
        self.result_trigger.emit(result)

        # Reset the Keras backend state so the thread class can be used to
        # train again later.
        # NOTE(review): save_model() may be invoked after this point; confirm
        # that saving still works once the session has been cleared.
        K.clear_session()

    def save_model(self, save_dir):
        """Save the trained model to the path ``save_dir``."""
        self.model.save(save_dir)
GUI显示代码(部分):
class MainWindow(QtGui.QMainWindow):
    """Main window (excerpt): shows training feedback and offers to save the model."""

    save_dir_signal = QtCore.pyqtSignal(str)  # carries the chosen save path

    def show_evaluate_result(self, evaluate_result):
        """Show the three accuracy percentages, then offer to save the model.

        :param evaluate_result: sequence whose first three items are the
            train, test and whole-dataset accuracies (already in percent).
        """
        # The dialog's return value is not needed (the original bound it to a
        # local named ``help``, shadowing the builtin).
        QtGui.QMessageBox.information(
            self, '评价结果',
            "训练集正确率: %.2f%%\n测试集正确率: %.2f%%\n数据集正确率: %.2f%%" %
            (evaluate_result[0], evaluate_result[1], evaluate_result[2]),
            QtGui.QMessageBox.Yes)

        self.pop_save_dir()

    def pop_save_dir(self):
        """Ask whether to save the model and, if so, emit the chosen path."""
        answer = QtGui.QMessageBox.information(
            self, '提示', '是否保存模型?',
            QtGui.QMessageBox.Yes | QtGui.QMessageBox.No)
        if answer != QtGui.QMessageBox.Yes:
            return

        save_dir = QtGui.QFileDialog.getSaveFileName(
            self, '选择保存目录', 'C:\\Users\\fuqia\\Desktop')
        if save_dir != '':  # an empty result means the dialog was cancelled
            save_dir = save_dir + '.model'
            self.save_dir_signal.emit(save_dir)

    def show_bp_result(self, result):
        """Keep the predictions and add them to the table widget."""
        self.predict_data = result
        TableWidgetDeal.add_predict_data(self.table, result)

    def waiting_label_close(self):
        """Close the waiting-animation label."""
        self.label.close()

    def show_waiting(self):
        """Show a frameless, translucent label playing the waiting GIF.

        This method then enters a local event loop. Nothing in the code shown
        here quits that loop, so the call does not return while the
        application runs.
        NOTE(review): the visible caller keeps its worker thread only in a
        local variable, so this blocking call may be what keeps that worker
        alive; change the two together if the loop is ever removed.
        """
        self.label = QtGui.QLabel(self)
        self.label.setFixedSize(640, 480)  # original author: misbehaves without a fixed size
        self.label.setWindowFlags(QtCore.Qt.FramelessWindowHint)  # no window frame
        self.label.setAttribute(QtCore.Qt.WA_TranslucentBackground)  # transparent background

        screen = QtGui.QDesktopWidget().screenGeometry()
        size = self.label.geometry()
        # Floor division keeps the coordinates integral on Python 3, where
        # ``/`` would hand floats to move(). The +240 horizontal offset comes
        # from the original author, who noted that the plain centred position
        # did not appear centred.
        self.label.move((screen.width() - size.width()) // 2 + 240,
                        (screen.height() - size.height()) // 2)

        # Load the GIF. CacheAll caches every frame; the original author
        # notes that loopCount() reports -1 (endless looping) in this setup.
        movie = QtGui.QMovie("./Icon/waiting.gif")
        movie.setCacheMode(QtGui.QMovie.CacheAll)
        movie.setSpeed(100)  # playback speed in percent
        # QLabel does not take ownership of its movie, so keep a reference
        # on the window instead of relying on this method's local variable.
        self._waiting_movie = movie
        self.label.setMovie(movie)
        movie.start()
        self.label.show()

        wait_loop = QtCore.QEventLoop()
        wait_loop.exec_()
# Create the training worker and keep it on ``self`` so the thread object
# stays referenced for as long as the window exists, not just for the
# lifetime of this local variable.
w = WorkThread()
self.work_thread = w
w.init(self.object.data_set, feature, label, self.bp_ui.bp_info)

# Connect every signal BEFORE starting the thread, so that no emission can
# happen while a slot is still unconnected.
w.finish_trigger.connect(self.waiting_label_close)
w.result_trigger.connect(self.show_bp_result)
w.evaluate_trigger.connect(self.show_evaluate_result)
self.save_dir_signal.connect(w.save_model)

w.start()
self.show_waiting()