对抗网络应用于人脸动漫化

1 代码和模型来源于Github(如有侵权请及时评论联系),建议作者原始链接学习

https://github.com/hpc203

2 测试硬件 MAC M1;预先安装 python opencv onnxruntime;注意pip安装的适用于编译运行python版本程序;brew安装的可用于c++程序;由于mac M1通过brew安装的onnxruntime,使用的时候某些函数会出现未定义,建议运行python版本;

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import cv2
import onnxruntime as ort

class AnimeGAN():
    def __init__(self):
        so = ort.SessionOptions()
        so.log_severity_level = 3
        self.net = ort.InferenceSession('face_paint_512_v2_0.onnx', so)
        self.input_size = 512
        self.input_name = self.net.get_inputs()[0].name
        self.output_name = self.net.get_outputs()[0].name
    def detect(self, image):
        img = cv2.resize(image, (self.input_size, self.input_size))
        x = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        x = x.transpose(2, 0, 1).astype('float32')
        x = x * 2 - 1
        x = x.reshape(-1, 3, self.input_size, self.input_size)

        outs = self.net.run([self.output_name], {self.input_name: x})[0].squeeze(axis=0)
        outs = (outs * 0.5 + 0.5).clip(0, 1)
        outs = outs * 255
        outs = outs.transpose(1, 2, 0).astype('uint8')
        outs = cv2.cvtColor(outs, cv2.COLOR_RGB2BGR)
        return cv2.resize(outs, (image.shape[1], image.shape[0]))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--imgpath', type=str, default='liushishi.jpg')
    args = parser.parse_args()

    model = AnimeGAN()
    srcimg = cv2.imread(args.imgpath)
    result = model.detect(srcimg)

    cv2.namedWindow('image', cv2.WINDOW_NORMAL)
    cv2.imshow('image', srcimg)
    winName = 'Deep learning AnimeGAN in ONNXRuntime'
    cv2.namedWindow(winName, cv2.WINDOW_NORMAL)
    cv2.imshow(winName, result)
    cv2.waitKey()
    cv2.destroyAllWindows()

 

 

 

posted on 2022-11-28 11:11  邗影  阅读(42)  评论(0编辑  收藏  举报

导航