spark上的深度学习——按照雅虎的做法,本质上就是rdd.pipe,推理部分直接代理给tensorflow
from:https://juejin.im/post/5ad4b620f265da23a04a0ad0 看原文代码即可知道本质
Deep Learning On Spark
经过刚才的介绍,我们知道spark是一个分布式的通用计算框架,而以tensorflow为代表的deep learning是一个分布式模型训练框架,它更多专注在梯度计算,那为什么要将两者整合呢?整合的意义在哪里?意义就是能实现更好的分布式训练和数据传输。
针对分布式训练的场景,雅虎开源了TensorflowOnSpark的开源框架,它主要实现tensorflow能够与spark相结合做分布式训练。同时也有其它的一些机制,例如,CaffeOnSpark、MMLSpark(CNTK)、PaddleOnSpark。
TensorflowOnSpark解决的核心问题是将spark作为分布式tensorflow的底层调动机制,通过spark executor去把tensorflow的进程调动起来,这样在进行tensorflow训练时就不需要手动地去组建网络。它也提供了一个API,通过调TFCluster.run这样一个API,可以快速获得tensorflow的一个分布式训练环境。
除此之外TensorflowOnSpark还提供了基于RDD的数据并行机制,如下图所示。这套机制非常方便地集成了spark已有的RDD处理机制,可以更好地跟spark sql或spark streaming去做相应的集成。
然后进入到另外一个方向,叫做spark-deep-learning,是由spark的创始公司—Data Bricks发起的,它主要的目标是提供一些high-level的API,把底层的模型进行组件化,同时它期望可以兼容底层深度式学习框架。
这里有个“Transfer Learning as a Pipeline”的例子供大家了解,如下图所示:
TensorflowOnSpark Pipeline开发了两个API,一个是TFEstimator,另一个是TFModel,提供了这两个之后,你可以直接把它们集成到spark-deep-learning pipeline里面,进行进一步的训练。
六.TensorflowOnSpark案例实践
最后一部分,我们来进行案例实践介绍,我们要解决的是一个图像分类问题,这里采用了一个kaggle dataset,叫做花朵识别,有5个类别,4000多张图片,包括郁金香、太阳花、蒲公英、玫瑰和雏菊这五种花。把这些数据预先存储于MongoDB中。我们的案例实践是一个分布式解决方案,包括分布式数据获取、分布式训练、分布式评估。
以下几张图片是代码示例,简单了解一下:
下面是效果演示,左边是图片,右边是模型预测结果,预测结果都是一个概率值,根据概率值的大小来判定这是哪一类花朵: