TensorFlow Lite Android face detection demo
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md
Setting up and installing the TensorFlow and Object Detection API environment
https://www.jianshu.com/p/286b8163da29
Installing Bazel
The setup above is needed to train the TensorFlow object detection model and to convert it from TensorFlow to TensorFlow Lite. The individual steps are covered later.
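Download Android Studio and import the TFLite Android demo from the TensorFlow source tree. It is located under tensorflow/contrib/lite/examples/android.
Let Android Studio download the required dependencies, then build the project to check that it compiles. With a reliable network connection, the build normally succeeds once all the packages have been downloaded automatically. Install the generated APK and test it. By default, it performs general object detection. For face detection, the only thing that needs to change is the model.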
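The object detection model in the demo is TensorFlow's quantized SSD MobileNet model (ssd-mobilenet-quantized), trained on the COCO dataset. We can use it as the starting point for transfer learning and obtain a face detection model from it.
The pretrained model can be downloaded from the TensorFlow Object Detection API model zoo.
For face data, the WIDER FACE dataset can be used. After downloading it, run a script to convert the images and their annotations into TFRecord format for training.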
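The conversion script itself is not shown here. The Java below is a rough sketch of its parsing half. It assumes the usual layout of the WIDER FACE ground-truth file: an image path, a face count, then one `x1 y1 w h ...` line per face. It also assumes that images with no faces still carry a single all-zero line. Each box is then normalized to the [0, 1] corner coordinates that the Object Detection API's TFRecord features (`image/object/bbox/xmin` and so on) use. `WiderFaceParser` and its methods are hypothetical names. Serializing the actual `tf.train.Example` records is left to the real script.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: parses WIDER FACE ground truth and normalizes the boxes. */
final class WiderFaceParser {

  /** A face box in pixels: top-left corner (x1, y1), width w, height h. */
  static final class PixelBox {
    final int x1, y1, w, h;

    PixelBox(int x1, int y1, int w, int h) {
      this.x1 = x1;
      this.y1 = y1;
      this.w = w;
      this.h = h;
    }
  }

  /** One image entry: its relative path and its face boxes. */
  static final class Entry {
    final String path;
    final List<PixelBox> boxes = new ArrayList<>();

    Entry(String path) {
      this.path = path;
    }
  }

  /**
   * Parses the lines of a ground-truth file. Assumed layout per image: path line, face-count line,
   * then one line per face whose first four columns are "x1 y1 w h" (remaining attribute columns
   * are ignored). A count of 0 is assumed to be followed by one all-zero placeholder line.
   */
  static List<Entry> parse(List<String> lines) {
    List<Entry> entries = new ArrayList<>();
    int i = 0;
    while (i < lines.size()) {
      String path = lines.get(i++).trim();
      if (path.isEmpty()) {
        continue; // tolerate blank lines between entries
      }
      Entry entry = new Entry(path);
      int count = Integer.parseInt(lines.get(i++).trim());
      int boxLines = Math.max(count, 1); // the zero-face placeholder still occupies one line
      for (int k = 0; k < boxLines; k++) {
        String[] f = lines.get(i++).trim().split("\\s+");
        int w = Integer.parseInt(f[2]);
        int h = Integer.parseInt(f[3]);
        // Skip the placeholder line and degenerate (zero-size) boxes.
        if (k < count && w > 0 && h > 0) {
          entry.boxes.add(new PixelBox(Integer.parseInt(f[0]), Integer.parseInt(f[1]), w, h));
        }
      }
      entries.add(entry);
    }
    return entries;
  }

  /**
   * Converts a pixel box to normalized {xmin, ymin, xmax, ymax}, clamped to [0, 1]. The image size
   * must come from the image file itself, because the annotation file does not record it.
   */
  static float[] normalize(PixelBox b, int imageWidth, int imageHeight) {
    return new float[] {
      clamp((float) b.x1 / imageWidth),
      clamp((float) b.y1 / imageHeight),
      clamp((float) (b.x1 + b.w) / imageWidth),
      clamp((float) (b.y1 + b.h) / imageHeight)
    };
  }

  private static float clamp(float v) {
    return Math.max(0f, Math.min(1f, v));
  }
}
```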
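In the research directory, locate the config file for the corresponding model: object_detection/samples/configs/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config
In this file, set the paths to the TFRecord files, the label map, and the pretrained checkpoint to start from. Because faces are the only class, also set num_classes to 1. Then start training: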
python train.py \
    --logtostderr \
    --train_dir=/home/kai/tensorflow/face/ \
    --pipeline_config_path=/home/kai/tensorflow/face/ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config
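Stop training once the loss is low enough, and keep the resulting checkpoint. We now have a model that can detect faces. Using it as a TensorFlow Lite model on a mobile device, however, takes some extra work.
TensorFlow Lite is Google's lightweight deep learning framework for mobile devices. It uses techniques such as quantized kernels to make models smaller and faster, which makes them better suited to running on mobile devices.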
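To make "quantized kernels" more concrete: TensorFlow Lite's uint8 quantization stores each value as a byte q together with a per-tensor scale and zero point, where real = scale × (q − zero_point). The sketch below shows the round trip. The class name and the example ranges are illustrative only and are not taken from a real model.

```java
/**
 * Sketch of asymmetric uint8 quantization: a real value r is stored as a byte value q in
 * [0, 255] with r = scale * (q - zeroPoint). Names are illustrative, not a TFLite API.
 */
final class Uint8Quantization {

  /** Real value -> quantized byte value, saturating to the uint8 range. */
  static int quantize(float real, float scale, int zeroPoint) {
    int q = Math.round(real / scale) + zeroPoint;
    return Math.max(0, Math.min(255, q));
  }

  /** Quantized byte value -> approximate real value. */
  static float dequantize(int q, float scale, int zeroPoint) {
    return scale * (q - zeroPoint);
  }

  /** Scale that spreads the real range [min, max] over the 256 byte values. */
  static float scaleFor(float min, float max) {
    return (max - min) / 255f;
  }

  /** Zero point chosen so that real 0 is represented exactly (the range must contain 0). */
  static int zeroPointFor(float min, float scale) {
    return Math.round(-min / scale);
  }
}
```

For example, with scale 2/255 and zero point 128, the value 0.5 is stored as 192 and read back as about 0.502. The small rounding error is what quantization trades for 4× smaller weights and integer arithmetic.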
First, export the checkpoint to a frozen .pb graph that can be converted for TensorFlow Lite:
python object_detection/export_tflite_ssd_graph.py \
    --pipeline_config_path=$CONFIG_FILE \
    --trained_checkpoint_prefix=$CHECKPOINT_PATH \
    --output_directory=$OUTPUT_DIR \
    --add_postprocessing_op=true
Make sure to use export_tflite_ssd_graph.py rather than export_inference_graph.py. A .pb exported with the latter cannot be converted in the next step.
Once tflite_graph.pb has been generated, use TOCO to convert it into a .tflite model. Run the following from the TensorFlow source directory:
bazel run -c opt tensorflow/contrib/lite/toco:toco -- \
    --input_file=$OUTPUT_DIR/tflite_graph.pb \
    --output_file=$OUTPUT_DIR/detect.tflite \
    --input_shapes=1,300,300,3 \
    --input_arrays=normalized_input_image_tensor \
    --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
    --inference_type=QUANTIZED_UINT8 \
    --mean_values=128 \
    --std_values=128 \
    --change_concat_input_ranges=false \
    --allow_custom_ops
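The `--mean_values=128 --std_values=128` flags tell TOCO how the uint8 input bytes relate to the float inputs the graph was trained on: real = (q − mean) / std. Bytes 0 to 255 therefore map to roughly [−1, 1), and the app can feed raw RGB bytes to the quantized model without normalizing them on the device. The following is a minimal sketch, assuming pixels packed as 0xAARRGGBB (the format returned by Android's `Bitmap.getPixels`). `QuantizedInput` is a hypothetical helper name.

```java
import java.nio.ByteBuffer;

/** Hypothetical helper: prepares input bytes for a QUANTIZED_UINT8 SSD model. */
final class QuantizedInput {
  // Must match the --mean_values / --std_values passed to TOCO.
  static final float MEAN = 128f;
  static final float STD = 128f;

  /** Packs ARGB pixels into a direct buffer of RGB bytes in row-major (HWC) order. */
  static ByteBuffer toRgbBytes(int[] argbPixels) {
    ByteBuffer buf = ByteBuffer.allocateDirect(argbPixels.length * 3);
    for (int p : argbPixels) {
      buf.put((byte) ((p >> 16) & 0xFF)); // R
      buf.put((byte) ((p >> 8) & 0xFF)); // G
      buf.put((byte) (p & 0xFF)); // B (alpha is dropped)
    }
    buf.rewind(); // ready to be read by the interpreter
    return buf;
  }

  /** The real value the converted graph associates with one input byte. */
  static float realValue(byte q) {
    return ((q & 0xFF) - MEAN) / STD;
  }
}
```

A 300×300 crop therefore becomes a 270,000-byte buffer, matching `--input_shapes=1,300,300,3`.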
If no errors are reported, a detect.tflite file is generated in $OUTPUT_DIR. This file is the TensorFlow Lite model.
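The `--output_arrays` above name the four outputs of the `TFLite_Detection_PostProcess` op:

- `:0` contains the boxes as normalized (ymin, xmin, ymax, xmax).
- `:1` contains the class indices.
- `:2` contains the scores.
- `:3` contains the number of valid detections.

The sketch below decodes these outputs into labeled pixel boxes. The label lookup assumes, as in the demo's `coco_labels_list.txt`, that line 0 of the label file is a background placeholder such as `???`. Under that assumption, `face.txt` for a one-class face model would contain `???` followed by `face`. Check this against `TFLiteObjectDetectionAPIModel` in your copy of the demo. `DetectionDecoder` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: decodes the outputs of TFLite_Detection_PostProcess. */
final class DetectionDecoder {

  /** A labeled box in pixel coordinates of the model input (inputSize x inputSize). */
  static final class Detection {
    final String label;
    final float score;
    final float left, top, right, bottom;

    Detection(String label, float score, float left, float top, float right, float bottom) {
      this.label = label;
      this.score = score;
      this.left = left;
      this.top = top;
      this.right = right;
      this.bottom = bottom;
    }
  }

  static List<Detection> decode(
      float[][][] boxes, float[][] classes, float[][] scores, float[] numDetections,
      List<String> labels, int inputSize, float minScore) {
    List<Detection> out = new ArrayList<>();
    // Only the first numDetections entries are valid.
    int n = Math.min((int) numDetections[0], scores[0].length);
    for (int i = 0; i < n; i++) {
      if (scores[0][i] < minScore) {
        continue;
      }
      float[] b = boxes[0][i]; // {ymin, xmin, ymax, xmax}, normalized to [0, 1]
      // Assumed label offset: index 0 is the first real class, line 0 of the file a placeholder.
      String label = labels.get((int) classes[0][i] + 1);
      out.add(new Detection(label, scores[0][i],
          b[1] * inputSize, b[0] * inputSize, b[3] * inputSize, b[2] * inputSize));
    }
    return out;
  }
}
```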
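Adding the model to the TensorFlow Lite demo
Copy the converted face model into the demo's assets as facedetect.tflite, together with its label file face.txt, and modify DetectorActivity.java as shown below. This version keeps the original COCO model (detect.tflite) as well. It runs the face detector first and falls back to the object detector only when no face is found. From the COCO results, only detections of the class "oven" are kept; this class name is hard-coded and can be changed as needed.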
/*
 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tensorflow.demo;

import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.graphics.Typeface;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.SystemClock;
import android.util.Size;
import android.util.TypedValue;
import android.widget.Toast;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Vector;
import org.tensorflow.demo.OverlayView.DrawCallback;
import org.tensorflow.demo.env.BorderedText;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.demo.tracking.MultiBoxTracker;
import org.tensorflow.lite.demo.R; // Explicit import needed for internal Google builds.

/**
 * An activity that runs a face detector on each camera frame and, when no face is found, falls
 * back to a COCO object detector. The resulting detections are tracked across frames.
 */
public class DetectorActivity extends CameraActivity implements OnImageAvailableListener {
  private static final Logger LOGGER = new Logger();

  // Configuration values for the face model (the converted detect.tflite, renamed).
  private static final int TF_OD_API_INPUT_SIZE = 300;
  private static final boolean TF_OD_API_IS_QUANTIZED = true;
  private static final String TF_OD_API_MODEL_FILE = "facedetect.tflite";
  private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/face.txt";

  // Configuration values for the prepackaged SSD COCO model.
  private static final int TF_OD_API_INPUT_SIZE_N = 300;
  private static final boolean TF_OD_API_IS_QUANTIZED_N = true;
  private static final String TF_OD_API_MODEL_FILE_N = "detect.tflite";
  private static final String TF_OD_API_LABELS_FILE_N =
      "file:///android_asset/coco_labels_list.txt";

  // Which detection model to use: by default uses Tensorflow Object Detection API frozen
  // checkpoints.
  private enum DetectorMode {
    TF_OD_API;
  }

  private static final DetectorMode MODE = DetectorMode.TF_OD_API;

  // Minimum detection confidence to track a detection.
  private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.4f;

  // The only COCO class reported when no face is found (hard-coded; change as needed).
  private static final String FALLBACK_OBJECT_TITLE = "oven";

  private static final boolean MAINTAIN_ASPECT = false;

  private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);

  private static final boolean SAVE_PREVIEW_BITMAP = false;
  private static final float TEXT_SIZE_DIP = 10;

  private Integer sensorOrientation;

  // Face detector.
  private Classifier detector;
  // General (COCO) object detector.
  private Classifier detector_n;

  private long lastProcessingTimeMs;
  private Bitmap rgbFrameBitmap = null;
  private Bitmap croppedBitmap = null;
  private Bitmap cropCopyBitmap = null;

  private boolean computingDetection = false;

  private long timestamp = 0;

  private Matrix frameToCropTransform;
  private Matrix cropToFrameTransform;

  // Tracker.
  private MultiBoxTracker tracker;

  private byte[] luminanceCopy;

  private BorderedText borderedText;

  @Override
  public void onPreviewSizeChosen(final Size size, final int rotation) {
    final float textSizePx =
        TypedValue.applyDimension(
            TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
    borderedText = new BorderedText(textSizePx);
    borderedText.setTypeface(Typeface.MONOSPACE);

    tracker = new MultiBoxTracker(this);

    // Both detectors read the same cropped bitmap, so their input sizes must match (300 here).
    int cropSize = TF_OD_API_INPUT_SIZE;

    // Face detector.
    try {
      detector =
          TFLiteObjectDetectionAPIModel.create(
              getAssets(),
              TF_OD_API_MODEL_FILE,
              TF_OD_API_LABELS_FILE,
              TF_OD_API_INPUT_SIZE,
              TF_OD_API_IS_QUANTIZED);
    } catch (final IOException e) {
      LOGGER.e("Exception initializing face detector!", e);
      Toast toast =
          Toast.makeText(
              getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
      toast.show();
      finish();
    }

    // General object detector.
    try {
      detector_n =
          TFLiteObjectDetectionAPIModel.create(
              getAssets(),
              TF_OD_API_MODEL_FILE_N,
              TF_OD_API_LABELS_FILE_N,
              TF_OD_API_INPUT_SIZE_N,
              TF_OD_API_IS_QUANTIZED_N);
    } catch (final IOException e) {
      LOGGER.e("Exception initializing object detector!", e);
      Toast toast =
          Toast.makeText(
              getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
      toast.show();
      finish();
    }

    previewWidth = size.getWidth();
    previewHeight = size.getHeight();

    sensorOrientation = rotation - getScreenOrientation();
    LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);

    LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
    rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
    croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);

    frameToCropTransform =
        ImageUtils.getTransformationMatrix(
            previewWidth, previewHeight,
            cropSize, cropSize,
            sensorOrientation, MAINTAIN_ASPECT);

    cropToFrameTransform = new Matrix();
    frameToCropTransform.invert(cropToFrameTransform);

    trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay);
    trackingOverlay.addCallback(
        new DrawCallback() {
          @Override
          public void drawCallback(final Canvas canvas) {
            tracker.draw(canvas);
            if (isDebug()) {
              //tracker.drawDebug(canvas);
            }
          }
        });

    addCallback(
        new DrawCallback() {
          @Override
          public void drawCallback(final Canvas canvas) {
            if (!isDebug()) {
              return;
            }
            final Bitmap copy = cropCopyBitmap;
            if (copy == null) {
              return;
            }

            final int backgroundColor = Color.argb(100, 0, 0, 0);
            canvas.drawColor(backgroundColor);

            final Matrix matrix = new Matrix();
            final float scaleFactor = 2;
            matrix.postScale(scaleFactor, scaleFactor);
            matrix.postTranslate(
                canvas.getWidth() - copy.getWidth() * scaleFactor,
                canvas.getHeight() - copy.getHeight() * scaleFactor);
            canvas.drawBitmap(copy, matrix, new Paint());

            final Vector<String> lines = new Vector<String>();
            if (detector_n != null) {
              final String statString = detector_n.getStatString();
              final String[] statLines = statString.split("\n");
              for (final String line : statLines) {
                lines.add(line);
              }
            }
            lines.add("");

            lines.add("Frame: " + previewWidth + "x" + previewHeight);
            lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
            lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
            lines.add("Rotation: " + sensorOrientation);
            lines.add("Inference time: " + lastProcessingTimeMs + "ms");

            borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
          }
        });
  }

  OverlayView trackingOverlay;

  @Override
  protected void processImage() {
    ++timestamp;
    final long currTimestamp = timestamp;
    byte[] originalLuminance = getLuminance();
    tracker.onFrame(
        previewWidth,
        previewHeight,
        getLuminanceStride(),
        sensorOrientation,
        originalLuminance,
        timestamp);
    trackingOverlay.postInvalidate();

    // No mutex needed as this method is not reentrant.
    if (computingDetection) {
      readyForNextImage();
      return;
    }
    computingDetection = true;
    LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread.");

    rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);

    if (luminanceCopy == null) {
      luminanceCopy = new byte[originalLuminance.length];
    }
    System.arraycopy(originalLuminance, 0, luminanceCopy, 0, originalLuminance.length);
    readyForNextImage();

    final Canvas canvas = new Canvas(croppedBitmap);
    canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
    // For examining the actual TF input.
    if (SAVE_PREVIEW_BITMAP) {
      ImageUtils.saveBitmap(croppedBitmap);
    }

    runInBackground(
        new Runnable() {
          @Override
          public void run() {
            LOGGER.i("Running detection on image " + currTimestamp);
            final long startTime = SystemClock.uptimeMillis();
            // Run the face detector first.
            final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);

            // Copy of the model input, shown in the debug overlay.
            cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
            final Canvas canvas = new Canvas(cropCopyBitmap);
            final Paint paint = new Paint();
            paint.setColor(Color.RED);
            paint.setStyle(Style.STROKE);
            paint.setStrokeWidth(2.0f);

            float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
            switch (MODE) {
              case TF_OD_API:
                minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
                break;
            }

            final List<Classifier.Recognition> mappedRecognitions =
                new LinkedList<Classifier.Recognition>();

            // Keep confident faces, mapped from crop coordinates back to the preview frame.
            for (final Classifier.Recognition result : results) {
              final RectF location = result.getLocation();
              if (location != null && result.getConfidence() >= minimumConfidence) {
                //canvas.drawRect(location, paint);
                cropToFrameTransform.mapRect(location);
                result.setLocation(location);
                mappedRecognitions.add(result);
              }
            }

            // Only when no face was found, run the COCO model and keep the fallback class.
            if (mappedRecognitions.isEmpty()) {
              final List<Classifier.Recognition> results_n =
                  detector_n.recognizeImage(croppedBitmap);
              for (final Classifier.Recognition result_n : results_n) {
                final RectF location = result_n.getLocation();
                if (location != null
                    && result_n.getConfidence() >= minimumConfidence
                    && FALLBACK_OBJECT_TITLE.equals(result_n.getTitle())) {
                  //canvas.drawRect(location, paint);
                  cropToFrameTransform.mapRect(location);
                  result_n.setLocation(location);
                  mappedRecognitions.add(result_n);
                }
              }
            }
            lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;

            // A single call per frame, so an empty second list cannot clear the first one's boxes.
            tracker.trackResults(mappedRecognitions, luminanceCopy, currTimestamp);
            trackingOverlay.postInvalidate();

            requestRender();
            computingDetection = false;
          }
        });
  }

  @Override
  protected int getLayoutId() {
    return R.layout.camera_connection_fragment_tracking;
  }

  @Override
  protected Size getDesiredPreviewFrameSize() {
    return DESIRED_PREVIEW_SIZE;
  }

  @Override
  public void onSetDebug(final boolean debug) {
    detector.enableStatLogging(debug);
    detector_n.enableStatLogging(debug);
  }
}
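The per-frame decision in `processImage` (faces first, then one COCO class as a fallback) can be pulled out into plain Java so that it can be unit-tested without Android types. This sketch keeps confident faces. Only if there are none does it request the object detections, through a `Supplier`, so that the second model need not run when a face is visible. It then keeps just the fallback class. `FaceFirstSelector` and `Det` are hypothetical stand-ins for the activity logic and `Classifier.Recognition`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

/** Hypothetical, Android-free version of the face-first selection rule. */
final class FaceFirstSelector {

  /** Stand-in for Classifier.Recognition: a class title and a confidence. */
  static final class Det {
    final String title;
    final float score;

    Det(String title, float score) {
      this.title = title;
      this.score = score;
    }
  }

  /**
   * Returns the confident faces if there are any; otherwise the confident detections whose title
   * equals fallbackTitle. The object detections are requested only when no face passes minScore.
   */
  static List<Det> select(
      List<Det> faces, Supplier<List<Det>> objects, String fallbackTitle, float minScore) {
    List<Det> kept = new ArrayList<>();
    for (Det d : faces) {
      if (d.score >= minScore) {
        kept.add(d);
      }
    }
    if (!kept.isEmpty()) {
      return kept; // a face was found: the object detections are never requested
    }
    for (Det d : objects.get()) {
      if (d.score >= minScore && fallbackTitle.equals(d.title)) {
        kept.add(d);
      }
    }
    return kept;
  }
}
```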