SHIHUC

好记性不如烂笔头,还可以分享给别人看看! 专注基础算法,互联网架构,人工智能领域的技术实现和应用。
  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

Keras/Tensorflow训练逻辑研究

Posted on 2018-02-28 20:31  shihuc  阅读(15346)  评论(0编辑  收藏  举报

Keras是什么,以及相关的基础知识,这里就不做详细介绍,请参考Keras学习站点http://keras-cn.readthedocs.io/en/latest/

 

Tensorflow作为backend时的训练逻辑梳理,主要是结合项目,研究了下源代码!

 

我们的项目是智能问答机器人,基于双向RNN(准确的说是GRU)网络,这里网络结构,就不做介绍,只研究其中的训练逻辑,我们的训练是基于fit_generator,即基于生成器模型,节省内存,有助效率提升。

什么是生成器以及生成器的工作原理,这里不表,属于python的基础范畴。

 

1. Keras的训练,是基于batch进行的,每一个batch训练过程,进行一次loss和acc的调整

1.1 .主要核心代码

A. /home/anaconda2/lib/python2.7/site-packages/keras/legacy/interfaces.py

1)里面的装饰器函数generate_legacy_interface里面。这里涉及到fit_generator这个最为核心的入口函数的执行过程。

2)python里面装饰器工作原理,非常类似java代码里面的AOP切面编程逻辑,即在正常的业务逻辑执行前,将before或者after或者两者都执行一下。

3)训练函数原型及重要参数解释

def fit_generator(self, generator,        #生成器,一个yield的函数,迭代返回数据
             steps_per_epoch,             #一次训练周期(具体epoch是什么含义,要理解清楚)里面进行多少次batch
             epochs=1,                    #设置进行几次全数据集的训练,每一次全数据集训练过程被定义成一个epoch,其实这个是可以灵活应用的
             verbose=1,                   #一个开关,打开时,打印清晰的训练数据,即加载ProgbarLogger这个回调函数
             callbacks=None,              #设置业务需要的回调函数,我们的模型中添加了ModelCheckpoint这个回调函数
             validation_data=None,        #验证用的数据源设置,evaluate_generator函数要用到这个数据源,我们的项目里面,这里也是一个生成器
             validation_steps=None,       #设置验证多少次数据后取平均值作为此epoch训练后的效果,val_loss,val_acc的值受这个参数直接影响
             class_weight=None,           #此参数以及后续参数,我们的项目采用的都是默认值,可以参考官方文档了解细节
             max_queue_size=10,
             workers=1,
             use_multiprocessing=False,
             initial_epoch=0)

 

B. /home/anaconda2/lib/python2.7/site-packages/keras/callbacks.py

1)这里重点有ModelCheckpoint这个回调函数,涉及到业务参数,其他回调都是keras框架默认行为。

2)callback这个类,其实是一个容器,具体表现为一个List,可以在git_generator运行时,基于该函数的入参,构建一个Callback的实例,即一个list里面装入业务需要的callback实例,这里默认会有BaseLogger以及History这个callback,然后会判断verbose为true时,会添加ProgbarLogger这个callback,除此之外,就是fit_generator函数入参callbacks传入的参数。一般都会传递ModelCheckpoint这个。

3)在git_generator这个基于生成器模式训练的过程中,每一个epoch结束(on_epoch_end)时,都要调用这个callback函数(ModelCheckpoint)进行模型数据写文件的操作

 

2. Keras训练时用到的几个重要回调函数(主要工作在on_batch_end里面)

回调函数是基于抽象类Callback实现的。下面是Callback的成员函数,便于理解。

   def __init__(self):
        self.validation_data = None

    def set_params(self, params):
        self.params = params

    def set_model(self, model):
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        pass

    def on_train_begin(self, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass

 

A. keras.callbacks.BaseLogger

统计该batch里面训练的loss以及acc的值,计入totals,乘以batch_size后。

def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        batch_size = logs.get('size', 0)
        self.seen += batch_size

        for k, v in logs.items():
            if k in self.totals:
                self.totals[k] += v * batch_size
            else:
                self.totals[k] = v * batch_size

在BaseLogger这个类的on_epoch_end函数里,执行对这个epoch训练数据的loss以及acc求平均值。

def on_epoch_end(self, epoch, logs=None):
        if logs is not None:
            for k in self.params['metrics']:
                if k in self.totals:
                    # Make value available to next callbacks.
                    logs[k] = self.totals[k] / self.seen

 

B. keras.callbacks.ModelCheckpoint

在on_epoch_end时会保存模型数据进入文件

def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn('Can save best model only with %s available, '
                                  'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                                  ' saving model to %s'
                                  % (epoch, self.monitor, self.best,
                                     current, filepath))
                        self.best = current
                        if self.save_weights_only:
                            self.model.save_weights(filepath, overwrite=True)
                        else:
                            self.model.save(filepath, overwrite=True)
                    else:
                        if self.verbose > 0:
                            print('Epoch %05d: %s did not improve' %
                                  (epoch, self.monitor))
            else:
                if self.verbose > 0:
                    print('Epoch %05d: saving model to %s' % (epoch, filepath))
                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    self.model.save(filepath, overwrite=True)

 

C.keras.callbacks.History

主要记录每一次epoch训练的结果,结果包含loss以及acc的值

 

D. keras.callbacks.ProgbarLogger

这个函数里面实现训练中间状态数据信息的输出,主要涉及进度相关信息。

 

3. 具体训练逻辑过程

A. 训练函数分析

a. model.fit_generator 训练入口函数(参考上面的函数原型定义), 我们项目中用tk_data_generator函数作为训练数据提供者(生成器)
1) callbacks.on_train_begin()
2) while epoch < epochs:
3)         callbacks.on_epoch_begin(epoch)
4)         while steps_done < steps_per_epoch:
5)             generator_output = next(output_generator)       #生成器next函数取输入数据进行训练,每次取一个batch大小的量
6)             callbacks.on_batch_begin(batch_index, batch_logs)
7)             outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
8)             callbacks.on_batch_end(batch_index, batch_logs)
            end of while steps_done < steps_per_epoch
            self.evaluate_generator(...)          #当一个epoch的最后一次batch执行完毕,执行一次训练效果的评估
9)      callbacks.on_epoch_end(epoch, epoch_logs)          #在这个执行过程中实现模型数据的保存操作
      end of while epoch < epochs
10) callbacks.on_train_end()


b. 特别介绍下train_on_batch
   train_on_batch (keras中的trainning.py)
        |_self._standardize_user_data
        |_self._make_train_function
        |_self.train_function (tensorflow的函数)
                        |_updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict,**self.session_kwargs)

 

B训练和验证的对比

a. 在每一个epoch的最后一个迭代(最后一次batch)时,要进行此轮epoch的校验(evaluate)

日志如下:

141/141 [==============================] - 12228s - loss: 0.5715 - acc: 0.6960 - val_loss: 0.5082 - val_acc: 0.7450


第一个141表示batch_index已经达到141,即steps_per_epoch参数规定的最后一步
第二个141表示steps_per_epoch,即一个epoch里面进行多少次batch处理
12228s 表示此batch处理结束所花费的时间
loss:此epoch里面的平均损失值
acc:此epoch里面的平均准确率   
val_loss:此epoch训练完后进行的evaluate得到的损失值
val_acc:此epoch训练完后进行的evaluate得到的正确率

 

b. 验证逻辑,和训练逻辑差不多,只是将validation_steps指定次数的test的值进行取平均值,得到validation_steps次test的均值作为本epoch训练的最终效果

self.evaluate_generator(validation_data,validation_steps,max_queue_size=max_queue_size,workers=workers,use_multiprocessing=use_multiprocessing)

1) while steps_done < steps:
2)           generator_output = next(output_generator)
3)         outs = self.test_on_batch(x, y, sample_weight=sample_weight)
4)对上述while得到的每次outs进行 averages.append(np.average([out[i] for out in all_outs],weights=batch_sizes))

其中重点test_on_batch

test_on_batch(self, x, y, sample_weight=None)
         |_self._standardize_user_data(x, y,sample_weight=sample_weight,check_batch_axis=True)
         |_self._make_test_function()
         |_self.test_function(ins)                    
                    |_updated = session.run(self.outputs + [self.updates_op],feed_dict=feed_dict,**self.session_kwargs)

 

c. train和test的重要区别,应该体现在下面的两个函数上

def _make_train_function(self):
        if not hasattr(self, 'train_function'):
            raise RuntimeError('You must compile your model before using it.')
        if self.train_function is None:
            inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
            if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
                inputs += [K.learning_phase()]

            with K.name_scope('training'):
                with K.name_scope(self.optimizer.__class__.__name__):
                    training_updates = self.optimizer.get_updates(
                        params=self._collected_trainable_weights,
                        loss=self.total_loss)
                updates = self.updates + training_updates
                # Gets loss and metrics. Updates weights at each call.
                self.train_function = K.function(inputs,
                                                 [self.total_loss] + self.metrics_tensors,
                                                 updates=updates,
                                                 name='train_function',
                                                 **self._function_kwargs)
def _make_test_function(self):
        if not hasattr(self, 'test_function'):
            raise RuntimeError('You must compile your model before using it.')
        if self.test_function is None:
            inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
            if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
                inputs += [K.learning_phase()]
            # Return loss and metrics, no gradient updates.
            # Does update the network states.
            self.test_function = K.function(inputs,
                                            [self.total_loss] + self.metrics_tensors,
                                            updates=self.state_updates,
                                            name='test_function',
                                            **self._function_kwargs)

经过前面的代码逻辑梳理,可以看到不管是train的过程还是test的过程,最终底层都是调用Tensorflow的session.run方法进行loss和acc的获取,细心的观察,会发现两个session.run函数的入参其实有点不同。

结合上面train和test的私有函数中标注红色的注释,以及用K.function生成函数的入参中,可以看出train和test的差异。

 

总结:

0. 训练过程中,每次权重的更新都是在一个batch上进行一次,是基于batch量的数据为单位进行一次权重的更新

1. 基于生成器模型训练数据,可以提升效率,降低对物理服务器性能,尤其是内存的要求

2. 训练过程中,Callback函数执行了大量的工作,包括loss、acc值的记录,以及训练中间结果的日志反馈,最重要的是模型数据的输出,也是通过callback的方式实现(ModelCheckpoint)

3. 训练(train)和验证(evaluate/validate)的逻辑近乎一样,训练要更新权重,但是验证过程,仅仅更新网络状态,不涉及权重(loss以及acc参数)信息的更新

4. 代码梳理过程中,得出结论,Keras对python编程基本功底要求还是有点高的,采用了推导式编程习惯,生成器,装饰器,回调等编程思想,另外,对矩阵运算,例如numpy.dot以及numpy.multiply的数学逻辑都有一定要求,否则比较难看懂。