facenet应用记录

准备

安装facenet-pytorch及其相关包
pip install facenet-pytorch
模型下载(网盘提取码: 74s4):下载 vggface2_*.pt 两个模型文件,
并将这两个文件存放于 ~/.cache/torch/checkpoints/ 目录下

代码

from facenet_pytorch import MTCNN, InceptionResnetV1
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as np
import PIL
import os

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Face detector / aligner; it is called below as
# ``mtcnn(img, return_prob=True)`` and yields an aligned crop (or None).
mtcnn = MTCNN(
    image_size=160,
    margin=0,
    min_face_size=20,
    thresholds=[0.6, 0.7, 0.7],
    factor=0.709,
    post_process=True,
    device=device,
)

# Embedding network loaded with the 'vggface2' pretrained weights, switched
# to evaluation mode and moved to the selected device.
resnet = InceptionResnetV1(pretrained='vggface2')
resnet.eval()
resnet = resnet.to(device)


def create_face_db(db_name, db_path, model, align_model):
    """Build a face-embedding database from a folder of labelled images.

    Every image under ``db_path`` is passed through ``align_model``; images
    for which it returns ``None`` (no face found) are skipped. The remaining
    aligned crops are embedded by ``model`` in a single batch and written
    with ``np.savez`` under the keys ``embed`` and ``names``.

    Args:
        db_name (string): Output database file name, in npz format.
        db_path (string): Image folder containing one sub-folder per person,
            named after that person and holding all of their images.
        model: Callable that maps a batch of aligned face tensors to
            embeddings (the script passes the module-level ``resnet``).
        align_model: Face detector/aligner, called as
            ``align_model(img, return_prob=True)`` and expected to return
            ``(aligned_tensor_or_None, prob)``.

    Raises:
        ValueError: If no face is detected in any image under ``db_path``.
    """

    def collate_fn(x):
        # The loader is used with its default batch size of 1; unwrap the
        # single (image, label) sample instead of building a batch.
        return x[0]

    dataset = datasets.ImageFolder(db_path)
    idx_to_class = {i: c for c, i in dataset.class_to_idx.items()}
    loader = DataLoader(dataset, collate_fn=collate_fn)

    aligned = []
    names = []
    for x, y in loader:
        x_aligned, prob = align_model(x, return_prob=True)
        if x_aligned is None:
            # No face found in this image; leave it out of the database.
            continue
        print('Face detected with probability: {:8f}'.format(prob))
        aligned.append(x_aligned)
        names.append(idx_to_class[y])

    if not aligned:
        # torch.stack() on an empty list fails with an obscure message;
        # report the actual problem instead.
        raise ValueError(
            'no face detected in any image under {!r}'.format(db_path)
        )

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    batch = torch.stack(aligned).to(device)
    # Inference only: skip autograd bookkeeping for the whole batch.
    with torch.no_grad():
        embeddings = model(batch).detach().cpu().numpy()
    np.savez(db_name, embed=embeddings, names=np.array(names))


def who(img_name, db_name, model, align_model):
    """Identify the person in an image by nearest embedding in the database.

    The image is aligned with ``align_model``, embedded with ``model`` and
    compared (Euclidean distance) against every embedding stored under the
    ``embed`` key of ``db_name``; the closest entry wins.

    Args:
        img_name (string): Path of the image to identify.
        db_name (string): Path of the npz database written by
            ``create_face_db`` (keys ``embed`` and ``names``).
        model: Callable that maps a batch of aligned face tensors to
            embeddings.
        align_model: Face detector/aligner, called as
            ``align_model(img, return_prob=True)`` and expected to return
            ``(aligned_tensor_or_None, prob)``.

    Returns:
        tuple: ``(name, dist)`` where ``name`` is the ``names`` entry of the
        closest database embedding and ``dist`` is that entry's row of the
        distance table: a list with one distance per query embedding.

    Raises:
        ValueError: If ``align_model`` finds no face in ``img_name``.
    """
    # ``import PIL`` by itself does not guarantee the Image submodule is
    # loaded, so import it explicitly here.
    from PIL import Image

    with open(img_name, 'rb') as img_file:
        # convert() forces the pixel data to be read while the file is open.
        img = Image.open(img_file).convert('RGB')

    x_aligned, prob = align_model(img, return_prob=True)
    if x_aligned is None:
        # Previously this fell through and crashed on ``None.unsqueeze``.
        raise ValueError('no face detected in {!r}'.format(img_name))
    print('Face detected with probability: {:8f}'.format(prob))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        query = model(x_aligned.unsqueeze(0).to(device)).detach().cpu().numpy()

    # Read both arrays eagerly so the archive can be closed right away.
    with np.load(db_name) as db:
        db_embeddings = db['embed']
        db_names = db['names']

    dists = [
        [np.linalg.norm(e1 - e2) for e2 in query]
        for e1 in db_embeddings
    ]
    # argmin works on the flattened table; with a single query embedding the
    # flat index is also the database row index.
    index = np.argmin(dists)
    return db_names[index], dists[index]


# Build the face database; this step can be skipped when no new images have been added.
create_face_db('db.npz', 'test_images', resnet, mtcnn)
# Look up the name of the person shown in the query image.
name, dist = who('WP_000098.jpg', 'db.npz', resnet, mtcnn)
print('it\'s %s, dist:%s' % (name, dist))

posted on 2020-07-05 18:46  haskell  阅读(284)  评论(0)  编辑  收藏  举报