生活会辜负努力的人，但不会辜负一直努力的人。

  博客园 :: 首页 :: 博问 :: 闪存 :: 新随笔 :: 联系 :: 订阅 订阅 :: 管理 ::

1、实现效果

2、相关代码

实现BP训练模型的线程类

 1 class WorkThread(QtCore.QThread):
 2     finish_trigger = QtCore.pyqtSignal()  # 关闭waiting_gif
 3     result_trigger = QtCore.pyqtSignal(pd.Series)  # 传递预测结果信号
 4     evaluate_trigger = QtCore.pyqtSignal(list)  # 传递正确率信号
 5 
 6     def __int__(self):
 7         super(WorkThread, self).__init__()
 8 
 9     def init(self, dataset, feature, label, info):
10         self.dataset = dataset
11         self.feature = feature
12         self.label = label
13         self.info = info
14 
15     # 可以认为,run()函数就是新的线程需要执行的代码
16     def run(self):
17         self.BP()
18 
19     def BP(self):
20         """
21         BP神经网络,返回标签的预测数据
22         :param parent:
23         :param dataset:
24         :param feature:
25         :param label:
26         :param info:
27         :return:
28         """
29         dataset = self.dataset
30         feature = self.feature
31         label = self.label
32         info = self.info
33 
34         input_dim = len(feature)
35         data_x = dataset[feature]  # 特征数据
36         data_y = dataset[label]  # 标签数据
37 
38         x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=info[0][3])
39     
40         # **********************建立一个简单BP神经网络模型*********************************
41         self.model = Sequential()  # 声明一个顺序模型
42         count = len(info)
43         for i in range(1, count-1):
44             if i == 1:
45                 self.model.add(Dense(info[i][0], activation=info[i][1], input_dim=input_dim, kernel_initializer=info[i][2]))  # 输入层,Dense表示BP层
46             else:
47                 self.model.add(Dense(info[i][0], activation=info[i][1], kernel_initializer=info[i][2]))
48 
49         # 添加输出层
50         self.model.add(Dense(info[count-1][0], activation=info[count-1][1], kernel_initializer=info[count-1][2]))
51 
52         sgd = SGD(lr=info[0][0], decay=1e-6, momentum=0.9, nesterov=True)
53         self.model.compile(loss='binary_crossentropy',  optimizer=sgd,  metrics=['accuracy'])  # 编译模型
54 
55         self.model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=info[0][1], batch_size=info[0][2])  # 训练模型1000次
56 
57         scores_train = self.model.evaluate(x_train, y_train, batch_size=10)
58         scores_test = self.model.evaluate(x_test, y_test, batch_size=10)
59         scores = self.model.evaluate(data_x, data_y, batch_size=10)
60 
61         self.finish_trigger.emit()         # 循环完毕后发出信号
62         list = [scores_train[1]*100, scores_test[1]*100, scores[1]*100]
63         self.evaluate_trigger.emit(list)
64         result = pd.Series(self.model.predict(data_x).T[0])
65         result.name = '预测(BP)'
66         self.result_trigger.emit(result)
67         K.clear_session()  # 反复调用model 模型
68 
69     def save_model(self, save_dir):
70         self.model.save(save_dir)  # 保存模型

GUI显示代码（部分）：

 1 class MainWindow(QtGui.QMainWindow):
 2     save_dir_signal = QtCore.pyqtSignal(str)  # 传递保存目录信号
 3 
 4 def show_evaluate_result(self, evaluate_result):
 5         help = QtGui.QMessageBox.information(self, '评价结果',
 6                                              "训练集正确率:  %.2f%%\n测试集正确率:  %.2f%%\n数据集正确率:  %.2f%%" %
 7                                              (evaluate_result[0], evaluate_result[1], evaluate_result[2]),
 8                                              QtGui.QMessageBox.Yes)
 9 
10         self.pop_save_dir()
11 
12     def pop_save_dir(self):
13         msg = QtGui.QMessageBox.information(self, '提示', '是否保存模型?', QtGui.QMessageBox.Yes | QtGui.QMessageBox.No)
14         if msg == QtGui.QMessageBox.Yes:
15                 save_dir = QtGui.QFileDialog.getSaveFileName(self, '选择保存目录', 'C:\\Users\\fuqia\\Desktop')
16 
17                 if save_dir != '':
18                     save_dir = save_dir + '.model'
19                     self.save_dir_signal.emit(save_dir)
20 
21     def show_bp_result(self, result):
22 
23         self.predict_data = result
24         TableWidgetDeal.add_predict_data(self.table, result)
25 
26     def waiting_label_close(self):
27         self.label.close()
28 
29     def show_waiting(self):
30         self.label = QtGui.QLabel(self)
31         self.label.setFixedSize(640, 480)  # 不加的话有问题???
32         self.label.setWindowFlags(QtCore.Qt.FramelessWindowHint)  # 无边框
33         self.label.setAttribute(QtCore.Qt.WA_TranslucentBackground)  # 背景透明
34 
35         screen = QtGui.QDesktopWidget().screenGeometry()
36         size = self.label.geometry()
37         # 如果是self.label.move((screen.width() - size.width()) / 2 , (screen.height() - size.height()) / 2)无法居中
38         self.label.move((screen.width() - size.width()) / 2 + 240, (screen.height() - size.height()) / 2)
39 
40         # 打开gif文件
41         movie = QtGui.QMovie("./Icon/waiting.gif")
42         # 设置cacheMode为CacheAll时表示gif无限循环,注意此时loopCount()返回-1
43         movie.setCacheMode(QtGui.QMovie.CacheAll)
44         # 播放速度
45         movie.setSpeed(100)
46         self.label.setMovie(movie)
47         # 开始播放,对应的是movie.start()
48         movie.start()
49         self.label.show()
50         q = QtCore.QEventLoop()
51         q.exec_()

 

# Start BP training in a background thread and show the waiting animation.
# Detach the previous run's save slot first; otherwise every past WorkThread
# stays connected and one "save" click makes all of them save to the same path.
previous_thread = getattr(self, 'work_thread', None)
if previous_thread is not None:
    try:
        self.save_dir_signal.disconnect(previous_thread.save_model)
    except TypeError:  # PyQt raises TypeError when the slot was not connected
        pass
w = WorkThread()
w.init(self.object.data_set, feature, label, self.bp_ui.bp_info)
# Connect every signal *before* start() so no emission can be missed.
w.finish_trigger.connect(self.waiting_label_close)
w.result_trigger.connect(self.show_bp_result)
w.evaluate_trigger.connect(self.show_evaluate_result)
self.save_dir_signal.connect(w.save_model)
# Keep a reference on the window so the running QThread is never garbage-collected.
self.work_thread = w
w.start()
self.show_waiting()

 

posted on 2018-06-16 22:58  何许亻也  阅读(1005)  评论(0)  编辑  收藏  举报