[Web CV] MobileNet Optimization

A Getting-Started Guide to TensorFlow.js


Image Input

 Photo --> Tensor

// Convert image pixels to a tensor. In TF.js 1.x+ the API is
// tf.browser.fromPixels; early releases exposed it as tf.fromPixels.
const tensor = tf.browser.fromPixels(imageData);
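
MobileNet expects a [1, 224, 224, 3] input (see the summary below), so in practice the raw pixels usually need resizing and scaling first. A minimal preprocessing sketch; imageData is any tf.browser.fromPixels-compatible source, and scaling to [0, 1] is an assumption (some MobileNet builds normalize to [-1, 1] instead):

// Resize to MobileNet's 224x224 input and scale pixel values, leak-free via tf.tidy.
const input = tf.tidy(() => {
    const pixels = tf.browser.fromPixels(imageData);              // [h, w, 3], int32
    const resized = tf.image.resizeBilinear(pixels, [224, 224]);  // [224, 224, 3], float32
    return resized.div(255).expandDims(0);                        // [1, 224, 224, 3]
});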

Dissecting the Model

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 224, 224, 3)       0         // layer 0
_________________________________________________________________
conv1_pad (ZeroPadding2D)    (None, 226, 226, 3)       0         
_________________________________________________________________
conv1 (Conv2D)               (None, 112, 112, 8)       216       
_________________________________________________________________
conv1_bn (BatchNormalization (None, 112, 112, 8)       32        
_________________________________________________________________
conv1_relu (ReLU)            (None, 112, 112, 8)       0         
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D)  (None, 112, 112, 8)       72        
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 112, 112, 8)       32        
_________________________________________________________________
conv_dw_1_relu (ReLU)        (None, 112, 112, 8)       0         
_________________________________________________________________
conv_pw_1 (Conv2D)           (None, 112, 112, 16)      128       
_________________________________________________________________
conv_pw_1_bn (BatchNormaliza (None, 112, 112, 16)      64        
_________________________________________________________________
conv_pw_1_relu (ReLU)        (None, 112, 112, 16)      0         // layer 10
_________________________________________________________________
conv_pad_2 (ZeroPadding2D)   (None, 114, 114, 16)      0         
_________________________________________________________________
conv_dw_2 (DepthwiseConv2D)  (None, 56, 56, 16)        144       
_________________________________________________________________
conv_dw_2_bn (BatchNormaliza (None, 56, 56, 16)        64        
_________________________________________________________________
conv_dw_2_relu (ReLU)        (None, 56, 56, 16)        0         
_________________________________________________________________
conv_pw_2 (Conv2D)           (None, 56, 56, 32)        512       
_________________________________________________________________
conv_pw_2_bn (BatchNormaliza (None, 56, 56, 32)        128       
_________________________________________________________________
conv_pw_2_relu (ReLU)        (None, 56, 56, 32)        0         
_________________________________________________________________
conv_dw_3 (DepthwiseConv2D)  (None, 56, 56, 32)        288       
_________________________________________________________________
conv_dw_3_bn (BatchNormaliza (None, 56, 56, 32)        128       
_________________________________________________________________
conv_dw_3_relu (ReLU)        (None, 56, 56, 32)        0         // layer 20
_________________________________________________________________
conv_pw_3 (Conv2D)           (None, 56, 56, 32)        1024      
_________________________________________________________________
conv_pw_3_bn (BatchNormaliza (None, 56, 56, 32)        128       
_________________________________________________________________
conv_pw_3_relu (ReLU)        (None, 56, 56, 32)        0         
_________________________________________________________________
conv_pad_4 (ZeroPadding2D)   (None, 58, 58, 32)        0         
_________________________________________________________________
conv_dw_4 (DepthwiseConv2D)  (None, 28, 28, 32)        288       
_________________________________________________________________
conv_dw_4_bn (BatchNormaliza (None, 28, 28, 32)        128       
_________________________________________________________________
conv_dw_4_relu (ReLU)        (None, 28, 28, 32)        0         
_________________________________________________________________
conv_pw_4 (Conv2D)           (None, 28, 28, 64)        2048      
_________________________________________________________________
conv_pw_4_bn (BatchNormaliza (None, 28, 28, 64)        256       
_________________________________________________________________
conv_pw_4_relu (ReLU)        (None, 28, 28, 64)        0         // layer 30
_________________________________________________________________
conv_dw_5 (DepthwiseConv2D)  (None, 28, 28, 64)        576       
_________________________________________________________________
conv_dw_5_bn (BatchNormaliza (None, 28, 28, 64)        256       
_________________________________________________________________
conv_dw_5_relu (ReLU)        (None, 28, 28, 64)        0         
_________________________________________________________________
conv_pw_5 (Conv2D)           (None, 28, 28, 64)        4096      
_________________________________________________________________
conv_pw_5_bn (BatchNormaliza (None, 28, 28, 64)        256       
_________________________________________________________________
conv_pw_5_relu (ReLU)        (None, 28, 28, 64)        0         
_________________________________________________________________
conv_pad_6 (ZeroPadding2D)   (None, 30, 30, 64)        0         
_________________________________________________________________
conv_dw_6 (DepthwiseConv2D)  (None, 14, 14, 64)        576       
_________________________________________________________________
conv_dw_6_bn (BatchNormaliza (None, 14, 14, 64)        256       
_________________________________________________________________
conv_dw_6_relu (ReLU)        (None, 14, 14, 64)        0         // layer 40
_________________________________________________________________
conv_pw_6 (Conv2D)           (None, 14, 14, 128)       8192      
_________________________________________________________________
conv_pw_6_bn (BatchNormaliza (None, 14, 14, 128)       512       
_________________________________________________________________
conv_pw_6_relu (ReLU)        (None, 14, 14, 128)       0         
_________________________________________________________________
conv_dw_7 (DepthwiseConv2D)  (None, 14, 14, 128)       1152      
_________________________________________________________________
conv_dw_7_bn (BatchNormaliza (None, 14, 14, 128)       512       
_________________________________________________________________
conv_dw_7_relu (ReLU)        (None, 14, 14, 128)       0         
_________________________________________________________________
conv_pw_7 (Conv2D)           (None, 14, 14, 128)       16384     
_________________________________________________________________
conv_pw_7_bn (BatchNormaliza (None, 14, 14, 128)       512       
_________________________________________________________________
conv_pw_7_relu (ReLU)        (None, 14, 14, 128)       0         
_________________________________________________________________
conv_dw_8 (DepthwiseConv2D)  (None, 14, 14, 128)       1152      // layer 50
_________________________________________________________________
conv_dw_8_bn (BatchNormaliza (None, 14, 14, 128)       512       
_________________________________________________________________
conv_dw_8_relu (ReLU)        (None, 14, 14, 128)       0         
_________________________________________________________________
conv_pw_8 (Conv2D)           (None, 14, 14, 128)       16384     
_________________________________________________________________
conv_pw_8_bn (BatchNormaliza (None, 14, 14, 128)       512       
_________________________________________________________________
conv_pw_8_relu (ReLU)        (None, 14, 14, 128)       0         
_________________________________________________________________
conv_dw_9 (DepthwiseConv2D)  (None, 14, 14, 128)       1152      
_________________________________________________________________
conv_dw_9_bn (BatchNormaliza (None, 14, 14, 128)       512       
_________________________________________________________________
conv_dw_9_relu (ReLU)        (None, 14, 14, 128)       0         
_________________________________________________________________
conv_pw_9 (Conv2D)           (None, 14, 14, 128)       16384     
_________________________________________________________________
conv_pw_9_bn (BatchNormaliza (None, 14, 14, 128)       512       // layer 60
_________________________________________________________________
conv_pw_9_relu (ReLU)        (None, 14, 14, 128)       0         
_________________________________________________________________
conv_dw_10 (DepthwiseConv2D) (None, 14, 14, 128)       1152      
_________________________________________________________________
conv_dw_10_bn (BatchNormaliz (None, 14, 14, 128)       512       
_________________________________________________________________
conv_dw_10_relu (ReLU)       (None, 14, 14, 128)       0         
_________________________________________________________________
conv_pw_10 (Conv2D)          (None, 14, 14, 128)       16384     
_________________________________________________________________
conv_pw_10_bn (BatchNormaliz (None, 14, 14, 128)       512       
_________________________________________________________________
conv_pw_10_relu (ReLU)       (None, 14, 14, 128)       0         
_________________________________________________________________
conv_dw_11 (DepthwiseConv2D) (None, 14, 14, 128)       1152      
_________________________________________________________________
conv_dw_11_bn (BatchNormaliz (None, 14, 14, 128)       512       
_________________________________________________________________
conv_dw_11_relu (ReLU)       (None, 14, 14, 128)       0         // layer 70
_________________________________________________________________
conv_pw_11 (Conv2D)          (None, 14, 14, 128)       16384     
_________________________________________________________________
conv_pw_11_bn (BatchNormaliz (None, 14, 14, 128)       512       
_________________________________________________________________
conv_pw_11_relu (ReLU)       (None, 14, 14, 128)       0         
_________________________________________________________________
conv_pad_12 (ZeroPadding2D)  (None, 16, 16, 128)       0         
_________________________________________________________________
conv_dw_12 (DepthwiseConv2D) (None, 7, 7, 128)         1152      
_________________________________________________________________
conv_dw_12_bn (BatchNormaliz (None, 7, 7, 128)         512       
_________________________________________________________________
conv_dw_12_relu (ReLU)       (None, 7, 7, 128)         0         
_________________________________________________________________
conv_pw_12 (Conv2D)          (None, 7, 7, 256)         32768     
_________________________________________________________________
conv_pw_12_bn (BatchNormaliz (None, 7, 7, 256)         1024      
_________________________________________________________________
conv_pw_12_relu (ReLU)       (None, 7, 7, 256)         0         // layer 80
_________________________________________________________________
conv_dw_13 (DepthwiseConv2D) (None, 7, 7, 256)         2304      
_________________________________________________________________
conv_dw_13_bn (BatchNormaliz (None, 7, 7, 256)         1024      
_________________________________________________________________
conv_dw_13_relu (ReLU)       (None, 7, 7, 256)         0         
_________________________________________________________________
conv_pw_13 (Conv2D)          (None, 7, 7, 256)         65536     
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz (None, 7, 7, 256)         1024      
_________________________________________________________________
conv_pw_13_relu (ReLU)       (None, 7, 7, 256)         0         
_________________________________________________________________
global_average_pooling2d_1 ( (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 100)               25700     // layer 88
_________________________________________________________________
dropout_1 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 50)                5050      // layer 90
_________________________________________________________________
dense_3 (Dense)              (None, 4)                 204       
=================================================================
Total params: 249,498
Trainable params: 30,954
Non-trainable params: 218,544

Counting up to conv_pw_13_relu there are 87 layers in total (indices 0–86); adding one output layer on top would make 88.

Yet the output head in the summary above consists of no fewer than five layers: global_average_pooling2d_1, dense_1, dropout_1, dense_2, and dense_3.
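
The table above is standard Keras model.summary() output. To verify the layer count from TF.js itself, a small sketch (assuming mobilenet is the loaded tf.LayersModel, which also offers mobilenet.summary() for the full table):

// List every layer's index, name, and output shape.
mobilenet.layers.forEach((layer, i) => {
    console.log(i, layer.name, JSON.stringify(layer.outputShape));
});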

Trainable Layers

Output of the Frozen Layers

First, we need to get rid of the model's dense layers.

const layer = mobilenet.getLayer('conv_pw_13_relu');

Now let's update the model so that this layer becomes its output.

mobilenet = tf.model({inputs: mobilenet.inputs, outputs: layer.output});
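
Putting the two steps together, an end-to-end sketch (inside an async function); the model URL is a placeholder for wherever your converted MobileNet is hosted, and tf.loadLayersModel is the TF.js 1.x+ name (older releases called it tf.loadModel):

// Load a converted MobileNet and truncate it at conv_pw_13_relu.
let mobilenet = await tf.loadLayersModel('https://example.com/mobilenet/model.json');
const layer = mobilenet.getLayer('conv_pw_13_relu');
mobilenet = tf.model({inputs: mobilenet.inputs, outputs: layer.output});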

Input to the Trainable Layers

Finally, we create the trainable model, but first we need to know the output shape of that last layer:

// layer.output is a symbolic tensor; its shape is [null, 7, 7, 256]
const layerOutput = layer.output.shape;

Knowing the shape, we can set up the "variable-parameter" (trainable) layers.

trainableModel = tf.sequential({
    layers: [
        tf.layers.flatten({inputShape: [7, 7, 256]}),
        tf.layers.dense({
            units: 100,
            activation: 'relu',
            kernelInitializer: 'varianceScaling',
            useBias: true
        }),
        tf.layers.dense({
            units: 2,
            kernelInitializer: 'varianceScaling',
            useBias: false,
            activation: 'softmax'
        })
    ]
});
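
Before it can learn anything, trainableModel still needs an optimizer and a loss. A minimal training sketch on precomputed activations (inside an async function); the tensors xs ([batch, 7, 7, 256] activations) and ys ([batch, 2] one-hot labels) are assumptions:

// Compile and fit on activations produced by the frozen MobileNet.
trainableModel.compile({
    optimizer: tf.train.adam(0.0001),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
});
await trainableModel.fit(xs, ys, {epochs: 10, batchSize: 32});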

 

Final Assembly

The output of the frozen model becomes the input of the trainable one.

const activation  = mobilenet.predict(input);
const predictions = trainableModel.predict(activation); 
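
Both predict calls allocate intermediate GPU tensors, so for per-frame inference it is worth wrapping the pipeline in tf.tidy so the activations are released after each call. A small sketch:

// Run the two-stage pipeline without leaking the intermediate activation tensor.
function classify(input) {
    return tf.tidy(() => {
        const activation = mobilenet.predict(input);
        return trainableModel.predict(activation);
    });
}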

[Front-End AI Series] Implementing Alipay's 扫五福 (Scan for Fu) Feature Purely in the Front End with TF.js


Key Metrics

The final compressed model is around 400 KB; with extreme compression it drops to around 200 KB (at the cost of some model accuracy).

Performance on a Mac is around 30 fps; on high-end Android phones it is around 15–20 fps.

Model Training

TFJS model isomorphism means that training and on-device inference use the same framework and code: the model is trained offline on the server with tfjs and deployed on the client with tfjs.

This was mainly to prove that tfjs is capable of training models; for real production we still recommend training with TensorFlow for Python, for reasons we won't belabor here.
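
For completeness, a sketch of what the server-side half can look like in Node; @tensorflow/tfjs-node supplies native bindings, while the model, the training tensors xs/ys, and the output path are all assumptions:

// Server-side training sketch with @tensorflow/tfjs-node.
const tf = require('@tensorflow/tfjs-node');

async function trainAndSave(model, xs, ys) {
    await model.fit(xs, ys, {epochs: 10});
    await model.save('file://./wufu-model');  // writes model.json + weights.bin
}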

Two Dense Layers

We append two dense layers after the conv_pw_13_relu layer, topped with a softmax activation. There is a small trick here: freeze MobileNet's pretrained parameters and train only the two newly added dense layers, which greatly speeds up convergence (a sketch follows). Finally we fuse the two models to obtain our Wufu model.
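
In TF.js, freezing is just flipping each layer's trainable flag before compiling. A minimal sketch, assuming mobilenet is the truncated model from earlier:

// Freeze every pretrained MobileNet layer so only the new dense layers learn.
for (const layer of mobilenet.layers) {
    layer.trainable = false;
}

In the two-model setup above this is implicit, since only trainableModel is ever fit; the flag matters once the frozen base and the new head are fused into a single model.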

Slimming the Model

At this point the model weighs in at about 4 MB, which is an unacceptable size for a model that has to be deployed on the client.

conv_pw_13_relu outputs [7, 7, 256], and our model puts a flatten layer directly after it, which inflates the parameter count: flattening yields 7 × 7 × 256 = 12,544 inputs, so the first dense layer alone needs over a million parameters.

The fix:

(1) Add a pooling layer, which cuts the model down to roughly 900 KB (see the sketch below).

(2) Then compress it with tfjs-converter, arriving at the final 400 KB model.
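
The arithmetic explains the saving: global average pooling reduces the [7, 7, 256] block to just 256 values, so the first dense layer shrinks to 256 × 100 + 100 = 25,700 parameters, exactly the dense_1 figure in the summary above. A sketch of the pooled head, with layer sizes taken from the summary (the dropout rate and hidden activations are assumptions):

// Replace flatten with global average pooling: [7, 7, 256] -> [256].
const head = tf.sequential({
    layers: [
        tf.layers.globalAveragePooling2d({inputShape: [7, 7, 256]}),
        tf.layers.dense({units: 100, activation: 'relu'}),
        tf.layers.dropout({rate: 0.5}),
        tf.layers.dense({units: 50, activation: 'relu'}),
        tf.layers.dense({units: 4, activation: 'softmax'})
    ]
});

Step (2)'s compression comes from tfjs-converter's weight quantization (for example, storing float32 weights in one or two bytes), which shrinks the weight file further at a small cost in accuracy.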

Deployment

Deployment is fairly simple: upload our model files (model.json and weights.bin) to a CDN, load the model into memory on the client via tfjs (a sketch below), and happily start recognizing Fu characters~
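
Loading is a single call; the CDN URL below is a placeholder, and tf.loadLayersModel is the TF.js 1.x+ name (older releases used tf.loadModel). Warming the model up with a dummy input is a common trick so the first real frame doesn't pay for shader compilation:

// Load the converted model from the CDN (placeholder URL) and warm it up.
const model = await tf.loadLayersModel('https://cdn.example.com/wufu/model.json');
tf.tidy(() => model.predict(tf.zeros([1, 224, 224, 3])));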

We are the front-end team (AI-Team) of Ant Financial's AI department, dedicated to building the best AI services and front-end AI applications. Anyone with the ambition to explore the future of front-end intelligence with us is welcome to join.

If you are interested, contact me at: yaohua.cyh@antfin.com

Source: https://juejin.im/post/5e3766db5188254e1b0c7062
