如何将训练好的Python模型给JavaScript使用?
2023-04-01 15:40 北桥苏 阅读(60) 评论(0) 编辑 收藏 举报前言
从前面的Tensorflow环境搭建到目标检测模型迁移学习,已经完成了一个简答的扑克牌检测器,不管是从图片还是视频都能从画面中识别出有扑克的目标,并标识出扑克点数。但是,我想在想让他放在浏览器上可能实际使用,那么要如何让Tensorflow模型转换成web格式的呢?接下来将从实践的角度详细介绍一下部署方法!
环境
- Windows10
- Anaconda3
- TensorFlow.js converter
converter介绍
converter全名是TensorFlow.js Converter,他可以将TensorFlow GraphDef模型(通过Python API创建的,可以先理解为Python模型) 转换成Tensorflow.js可读取的模型格式(json格式), 用于在浏览器上对指定数据进行推算。
converter安装
为了不影响前面目标检测训练环境,这里我用conda创建了一个新的Python虚拟环境,Python版本3.6.8。在安装转换器的时候,如果当前环境没有Tensorflow,默认会安装与TF相关的依赖,只需要进入指定虚拟环境,输入以下命令。
1 | pip install tensorflowjs |
converter用法
1 | tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve ./saved_model ./web_model |
1. 产生的文件(生成的web格式模型)
转换器命令执行后生产两种文件,分别是model.json (数据流图和权重清单)和group1-shard\*of\* (二进制权重文件)
2. 输入的必要条件(命令参数和选项[带--为选项])
converter转换指令后面主要携带四个参数,分别是输入模型的格式,输出模型的格式,输入模型的路径,输出模型的路径,更多帮助信息可以通过以下命令查看,另附命令分解图。
1 | tensorflowjs_converter --help |
2.1. --input_format
要转换的模型的格式,SavedModel 为 tf_saved_model, frozen model 为 tf_frozen_model, session bundle 为 tf_session_bundle, TensorFlow Hub module 为 tf_hub,Keras HDF5 为 keras。
2.2. --output_format
输出模型的格式, 分别有tfjs_graph_model (tensorflow.js图模型,保存后的web模型没有了再训练能力,适合SavedModel输入格式转换),tfjs_layers_model(tensorflow.js层模型,具有有限的Keras功能,不适合TensorFlow SavedModels转换)。
2.3. input_path
saved model, session bundle 或 frozen model的完整的路径,或TensorFlow Hub模块的路径。
2.4. output_path
输出文件的保存路径。
2.5. --saved_model_tags
只对SavedModel转换用的选项:输入需要加载的MetaGraphDef相对应的tag,多个tag请用逗号分隔。默认为 serve
。
2.6. --signature_name
对TensorFlow Hub module和SavedModel转换用的选项:对应要加载的签名,默认为default
。
2.7. --output_node_names
输出节点的名字,每个名字用逗号分离。
3. 常用的两组命令行
1 2 3 4 5 6 | 1. covert from saved_model tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve ./saved_model ./web_model 2. convert from frozen_model tensorflowjs_converter --input_format=tf_frozen_model --output_node_names= 'num_detections,detection_boxes,detection_scores,detection_classes' ./frozen_inference_graph.pb ./web_modelk |
开始实践
1. 找到通过export_inference_graph.py导出的模型
导出的模型在项目的inference_graph文件夹(models\research\object_detection)里,frozen_inference_graph.pb是 tf_frozen_model输入格式需要的,而saved_model文件夹就是tf_saved_model格式。在当前目录下新建web_model目录,用于存储转换后的web格式的模型。
2. 开始转换
在当前虚拟环境下,进入到inference_graph目录下,输入以下命令,之后就会在web_model生成一个json文件和多个权重文件。
1 | tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve ./saved_model ./web_model |
3. 浏览器端部署
3.1. 创建一个前端项目,将web_model放入其中。
3.2.编写代码
<!doctype html> <head> <link rel= "stylesheet" href= "tfjs-examples.css" /> <style> canvas {outline: orange 2px solid; margin: 10px 0;} </style> </head> <body> <div class = "tfjs-example-container centered-container" > <section class = 'title-area' > <h1>赌圣2023</h1> </section> <p class = 'section-head' >模型描述</p> <p>我看你怎么出老千!</p> <p class = 'section-head' >模型状态</p> <div id= "status" >加载模型中...</div> <div> <p class = 'section-head' >效果展示</p> <p></button><input type= "file" accept= "image/*" id= "test" /></p> <canvas id= "data-canvas" width= "300" height= "1100" ></canvas> </div> </div> </body> <script src= "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js" ></script> <script> const canvas = document.getElementById( 'data-canvas' ); const status = document.getElementById( 'status' ); const testModel = document.getElementById( 'test' ); const BOUNDING_BOX_LINE_WIDTH = 3; const BOUNDING_BOX_STYLE1 = 'rgb(0,0,255)' ; const BOUNDING_BOX_STYLE2 = 'rgb(0,255,0)' ; async function init() { const LOCAL_MODEL_PATH = './web_model/model.json' ; // 将本地模型保存到浏览器 // tf.sequential().save // 加载本地模型 let model; try { model = await tf.loadGraphModel(LOCAL_MODEL_PATH); testModel.disabled = false ; status.textContent = '成功加载本地模型!请亮出你的牌吧' ; // 默认扑克牌 runAndVisualizeInference( './cam_image39.jpg' , model) } catch (err) { console.log( '加载本地模型错误:' , err); status.textContent = '加载本地模型失败' ; } testModel.addEventListener( 'change' , (e) => { runAndVisualizeInference(e, model) }); } async function runAndVisualizeInference(e, model) { if ( typeof e === 'string' ) { await new Promise((resolve, reject) => { // 图片显示在canvas中 var img = new Image; img.src = e; img.onload = function () { // 必须onload之后再画 let w = 500; let h = img.height/img.width*500; canvas.width = w; canvas.height = h; var ctx = canvas.getContext( '2d' ); ctx.drawImage(img,0,0,w,h); resolve(); } }) } else { // 上传图片并显示在canvas中 var file = e.target.files[0]; if (!/image\/\w+/.test(file.type)) { alert( "请确保文件为图像类型" ); return false ; } var reader = new FileReader(); reader.readAsDataURL(file); // 转化成base64数据类型 await new Promise((resolve, reject) => { reader.onload = function (e) { // 图片显示在canvas中 var img = new Image; img.src = this .result; img.onload = function () { // 必须onload之后再画 let w = 500; let h = img.height/img.width*500; canvas.width = w; canvas.height = h; var ctx = canvas.getContext( '2d' ); ctx.drawImage(img,0,0,w,h); resolve(); } } }) } // 模型输入处理 let image = tf.browser.fromPixels(canvas); const t4d = image.expandDims(0); const outputDim = [ 'num_detections' , 'detection_boxes' , 'detection_scores' , 'detection_classes' ]; const labelMap = { 1: '九点' , 2: '十点' , 3: 'Jack' , 4: 'Queen' , 5: 'King' , 6: 'Ace' } let modelOut = {}, boxes = [], w = canvas.width, h = canvas.height; console.log(model) for ( const dim of outputDim) { let tensor = await model.executeAsync({ 'image_tensor' : t4d }, `${dim}:0`); modelOut[dim] = await tensor.data(); } console.log(modelOut) for ( let i=0; i<modelOut[ 'detection_scores' ].length; i++) { const score = modelOut[ 'detection_scores' ][i]; if (score < 0.5) break ; // 置信度过滤 boxes.push({ ymin: modelOut[ 'detection_boxes' ][i*4]*h, xmin: modelOut[ 'detection_boxes' ][i*4+1]*w, ymax: modelOut[ 'detection_boxes' ][i*4+2]*h, xmax: modelOut[ 'detection_boxes' ][i*4+3]*w, label: labelMap[modelOut[ 'detection_classes' ][i]], }) } console.log(boxes) // 可视化检测框 drawBoundingBoxes(canvas, boxes); // 张量运行内存清除 tf.dispose([image, modelOut]); } function drawBoundingBoxes(canvas, predictBoundingBoxArr) { for ( const box of predictBoundingBoxArr) { let left = box.xmin; let right = box.xmax; let top = box.ymin; let bottom = box.ymax; const ctx = canvas.getContext( '2d' ); ctx.beginPath(); ctx.strokeStyle = box.label=== 'ZERO_DEV' ?BOUNDING_BOX_STYLE1:BOUNDING_BOX_STYLE2; ctx.lineWidth = BOUNDING_BOX_LINE_WIDTH; ctx.moveTo(left, top); ctx.lineTo(right, top); ctx.lineTo(right, bottom); ctx.lineTo(left, bottom); ctx.lineTo(left, top); ctx.stroke(); ctx.font = '24px Arial bold' ; ctx.fillStyle = box.label=== 'zfc' ?BOUNDING_BOX_STYLE2:BOUNDING_BOX_STYLE1; ctx.fillText(box.label, left+8, top+8); } } init(); </script> |
3.3. 运行结果
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律