Deep Learning (PyTorch) - 2. Extracting 68 Facial Landmarks with dlib from Python

Posted on 2018-03-02 16:26  LOMOoO

While reading the official PyTorch tutorial, I came across a script someone else had written. It is very concise.

Official tutorial: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#sphx-glr-beginner-data-loading-tutorial-py

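It relies on dlib's bundled frontal face detector and its pre-trained 68-point landmark (shape) predictor. That is good enough for early testing.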
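The predictor returns 68 points indexed 0 to 67. The sketch below groups those indices by facial region. It is based on the iBUG 300-W 68-point markup that dlib's model was trained on. The region names and the `LANDMARK_GROUPS` / `group_sizes` helpers are hypothetical illustrations, not part of dlib's API. "Left" and "right" are assumed to be from the subject's point of view.

```python
# Hypothetical helper: index ranges of the 68-point markup, grouped by region.
# Assumption: grouping follows the iBUG 300-W convention used for dlib's
# 68-point model; "left"/"right" are from the subject's point of view.
LANDMARK_GROUPS = {
    "jaw": range(0, 17),
    "right_eyebrow": range(17, 22),
    "left_eyebrow": range(22, 27),
    "nose": range(27, 36),
    "right_eye": range(36, 42),
    "left_eye": range(42, 48),
    "mouth": range(48, 68),
}


def group_sizes():
    """Return the number of landmark indices in each region."""
    return {name: len(idx) for name, idx in LANDMARK_GROUPS.items()}


if __name__ == "__main__":
    for name, size in group_sizes().items():
        print(f"{name:14s} {size:2d} points")
    print("total:", sum(group_sizes().values()))
```

With these ranges you can pick one region out of a `(68, 2)` landmark array by slicing. For example, `landmarks[36:42]` selects the six points of one eye.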


"""Create a sample face landmarks dataset.

Adapted from dlib/python_examples/face_landmark_detection.py
See this file for more explanation.

Download a trained facial shape predictor from:
    http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
Decompress it and put the .dat file in the working directory.
Run the script from the folder that holds the .jpg images
(the plotting script reads faces/face_landmarks.csv).
"""
import dlib
import glob
import csv
from skimage import io

# HOG-based frontal face detector bundled with dlib
detector = dlib.get_frontal_face_detector()
# Pre-trained 68-point shape predictor (the decompressed .dat file)
predictor = dlib.shape_predictor('shape_predictor_68_face_landmarks.dat')
num_landmarks = 68

with open('face_landmarks.csv', 'w', newline='') as csvfile:
    csv_writer = csv.writer(csvfile)

    # Header: image_name, part_0_x, part_0_y, ..., part_67_x, part_67_y
    header = ['image_name']
    for i in range(num_landmarks):
        header += ['part_{}_x'.format(i), 'part_{}_y'.format(i)]

    csv_writer.writerow(header)

    for f in glob.glob('*.jpg'):
        img = io.imread(f)
        # Face detection; the second argument upsamples the image once,
        # which helps find smaller faces.
        dets = detector(img, 1)

        # Skip files in which no face or more than one face was detected.
        if len(dets) == 1:
            row = [f]

            d = dets[0]
            # Get the landmarks/parts for the face in box d.
            shape = predictor(img, d)
            for i in range(num_landmarks):
                part_i_x = shape.part(i).x
                part_i_y = shape.part(i).y
                row += [part_i_x, part_i_y]

            csv_writer.writerow(row)
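You can check the CSV layout the script produces without dlib or any images. This minimal sketch reproduces the header and row construction with the standard `csv` module, writing to an in-memory buffer. It is not part of the original script. The helper names `make_header` and `make_row` are hypothetical, and made-up coordinates stand in for the `shape.part(i).x` / `.y` values.

```python
import csv
import io

NUM_LANDMARKS = 68


def make_header(num_landmarks=NUM_LANDMARKS):
    """Build the header row: image_name, part_0_x, part_0_y, ..., part_67_y."""
    header = ["image_name"]
    for i in range(num_landmarks):
        header += [f"part_{i}_x", f"part_{i}_y"]
    return header


def make_row(image_name, points):
    """Flatten [(x0, y0), (x1, y1), ...] into [image_name, x0, y0, x1, y1, ...]."""
    row = [image_name]
    for x, y in points:
        row += [x, y]
    return row


if __name__ == "__main__":
    # Hypothetical coordinates standing in for the predictor's output.
    fake_points = [(i, 2 * i) for i in range(NUM_LANDMARKS)]
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(make_header())
    writer.writerow(make_row("example.jpg", fake_points))
    header_line, data_line = buf.getvalue().splitlines()
    print(header_line[:40])
    print(data_line[:40])
```

Each data row therefore holds 1 + 2 × 68 = 137 fields, matching the header.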

Here is a script that uses matplotlib to plot the landmarks on the image:

from __future__ import print_function, division  # no-op on Python 3
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
# torch/torchvision imports are used by the later Dataset/DataLoader
# sections of the tutorial, not by this snippet.
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 5  # row index; the CSV must contain at least n + 1 faces
img_name = landmarks_frame.iloc[n, 0]
# as_matrix() was removed in pandas 1.0; to_numpy() replaces it.
landmarks = landmarks_frame.iloc[n, 1:].to_numpy()
# 136 flat values -> 68 rows of (x, y)
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('faces/', img_name)),
               landmarks)
# In interactive mode show() does not block; block=True keeps the window
# open when this is run as a script.
plt.show(block=True)
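In the plotting script, `.iloc[n, 1:]` followed by `reshape(-1, 2)` turns the 136 flat values of one CSV row into 68 `(x, y)` rows. Below is a rough standard-library equivalent. The `to_pairs` helper is hypothetical and not part of the tutorial. It assumes the values arrive as strings, as `csv.reader` returns them.

```python
def to_pairs(flat):
    """Group a flat sequence [x0, y0, x1, y1, ...] into [(x0, y0), (x1, y1), ...]."""
    if len(flat) % 2:
        raise ValueError("expected an even number of coordinates")
    values = [float(v) for v in flat]
    it = iter(values)
    # zip pulls from the same iterator twice per step, pairing x with y.
    return list(zip(it, it))


if __name__ == "__main__":
    # Hypothetical (shortened) row as read back from face_landmarks.csv.
    row = ["example.jpg", "10", "20", "11", "25", "13", "30"]
    name, coords = row[0], row[1:]
    print(name, to_pairs(coords))
```

NumPy's `reshape(-1, 2)` does the same grouping in one call. The explicit version only makes the x/y interleaving of the CSV columns visible.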

Result: