【机器学习】基于AnimeGAN的漫画人脸生成系统
链接
https://github.com/WanYongyi/machine-learning?tab=readme-ov-file
一、课题描述
1.1 课题背景
1.1.1 论文背景
动画是一种被广泛应用于广告、电影、儿童教育的日常生活中常见的艺术形式。目前,动画的制作主要依靠手工实现。然而,手工制作动画是非常费力的,需要大量的艺术技巧。对于动画艺术家来说,创作高质量的动画作品需要考虑线条、纹理、颜色、阴影等在内的要素,导致创作作品的难度和耗时。因此,能够将真实世界的照片自动转换成高质量的动画风格图像的自动化技术是非常有价值的。
目前,基于深度学习的图像到图像翻译已经取得了很好的效果。近年来,基于学习风格迁移方法[1]-[2]已经成为常见的图像到图像翻译方法。生成对抗网络(GAN)也广泛应用于风格迁移,虽然取得了一定的成功,但也存在着许多明显的问题,主要包括以下几点:1)生成的图像没有明显的动画风格纹理;2)生成的图像失去了原照片的内容;3)大量的网络参数需要较多的内存容量。
1.1.2 论文创新点
Chen等人的工作[3]主要解决了上述背景中存在的三点问题,论文的创新点主要有以下三点:
第一,提出了一种新的轻量级GAN,称为AnimeGAN,它可以将真实世界的照片快速转换为高质量的动画图像。提出的AnimeGAN是一种轻量级的生成对抗模型,具有较少的网络参数,并引入Gram矩阵[4]来获取更生动的风格图像。需要一组照片和一组动画图像进行训练。为了产生高质量的结果,同时使训练数据容易获取,采用了未配对数据进行训练,这意味着训练集中的照片和动画图像在内容上是不相关的。
第二,为了进一步提高生成图像的动画视觉效果,提出了三种新的简单有效的损失函数。提出的损失函数是灰度风格损失、颜色重建损失和灰度对抗损失。在生成网络中,灰度风格损失和颜色重建损失使生成的图像具有更明显的动画风格,并保留了照片的颜色。鉴别器网络中的灰度对抗损失使生成的图像具有鲜艳的色彩。在判别器网络中,则使用了Chen等人文章[5]中提出的促进边缘的对抗损失来保持清晰的边缘。
第三,为了让生成的图像具有原始照片的内容,引入预训练的VGG19[6]作为感知网络,获得生成图像和原始照片的深度感知特征的L1损失。在AnimeGAN开始训练之前对生成器进行初始化训练,使AnimeGAN的训练更稳定。
1.2 目的
机器学习是计算机及其相关学科的一门重要的学科课程,也是人工智能重要学科分支。旨在通过阅读文章,复现模型,研究和应用扩展来完成指定的功能,强化机器学习设计能力,培养独立查找资料能力、自学能力和运用所学知识解决新问题的能力,提高综合素质,并通过课程设计进一步加强对所学知识的理解,进一步提高调用、设计、开发机器学习系统模块以解决实际问题的能力。
1.3 技术要求
主要包括以下几个方面:
-
编程语言:需要掌握至少一门编程语言,如Python、MATLAB或Java等,以便实现机器学习算法和数据处理。
-
数据处理技能:需要具备数据清洗、数据预处理和数据转换等方面的技能,以便准备训练和测试数据集。
-
机器学习算法:需要了解和掌握各种机器学习算法,如分类算法、聚类算法、回归算法等,并能够根据实际需求选择合适的算法。
-
模型评估:需要掌握模型评估的方法和技术,如交叉验证、准确率、召回率、F1值等,以便对机器学习模型进行性能评估和优化。
-
工具和库:需要了解和使用一些常用的机器学习和数据处理工具和库,如Scikit-learn、TensorFlow、Keras、Pandas等,以便提高数据处理和机器学习算法实现的效率和精度。
-
理论知识和数学基础:需要具备一定的理论知识和数学基础,如统计学、线性代数、概率论等,以便更好地理解和应用机器学习算法。
1.4 课题完成要求
1.4.1 主要任务
-
阅读文献,了解文献的实验方法、算法、创新点;
-
复现代码,掌握代码的核心部分:损失函数L(G,D)的计算方法;
-
获取数据集,通过爬虫技术在视觉中国等平台获取数据集用于测试算法;
-
设计UI,开发一个前端界面,可以展示我们的训练效果,开放用户上传自行转换。
1.4.2 具体要求
-
进行课题的前期准备,要求理解课题内容和要求,并开展资料收集和分析、设计工作;
-
在课题总体设计基础上,编写程序的各个子功能模块,调试程序并运行;
-
撰写课程设计报告,要求报告结构合理,格式规范。
二、系统总体设计
2.1 包
模型的包
1、argparse:
作用:用于解析命令行参数。
解释:
argparse.ArgumentParser():创建一个参数解析器。
parser.add_argument():定义脚本接受的命令行参数。
parser.parse_args():解析命令行参数并返回一个包含参数的命名空间。
2、torch:
作用:PyTorch深度学习框架。
解释:
torch.backends.cudnn.enabled:禁用cuDNN加速。
torch.backends.cudnn.benchmark:使用cuDNN的性能优化。
torch.backends.cudnn.deterministic:确保cuDNN的结果是确定性的。
3、cv2 (OpenCV):
作用:用于图像处理。
解释:
cv2.imread():读取图像。
cv2.cvtColor():进行颜色空间转换。
cv2.resize():调整图像大小。
cv2.imwrite():保存图像。
4、numpy:
作用:用于处理数值数组。
解释:
np.float32:32位浮点数。
np.uint8:8位无符号整数。
数组操作,如 astype()、clip()。
5、os:
作用:提供与操作系统交互的功能。
解释:
os.makedirs():创建目录。
os.path.join():连接路径。
6、model (custom module):
作用:包含自定义的模型。
解释:
Generator:生成器模型,可能是一个GAN的生成器。
前端的包
Django项目配置:
python==3.9
django==3.2.1
mysqlclient == 2.2.0
mysql: mysql-5.7.41-winx64
模型运行配置:
torch==2.1.1+cu121
numpy==1.26.0
opencv-python==4.8.1.78
2.2 算法、原理
2.2.1 AnimeGAN架构
架构由两个卷积网络组成,一个是生成器G,用于将显示场景的照片转换成动画图像;另一个是鉴别器D,用于区分图像是来自真实目标域还是生成器产出的输出。以下图是论文中对架构的阐释。
生成器(G)中,所有框上的数字表示通道数量,SUM表示元素总和;鉴别器(D)中,K为核的大小,C为特征映射个数,S为卷积层的步幅,Inst Norm表示实例归一化层。
关于上图的一些解释:生成器可以看做一个对称的编码器-解码器网络,由标准卷积、深度可分离卷积、倒残差块(IRBs)、上下采样模块组成。在生成器中,具有1×1卷积核的最后一个卷积层不用归一化,后面是tanh非线性激活函数。
论文中还提到了其小模块,如下图所示:
2.2.2 损失函数
生成器损失函数主要分为4个部分,不同的损失有不同的权重系数,公式有六个,如下图:
公式(1)中,对抗损失(adv)是生成器G中影响动画转换过程的对抗性损失,内容损失(con)是帮助生成的图像保留输入照片内容的内容丢失,灰度风格损失(gra)是使生成的图像在纹理和线条上具有清晰的动漫风格,颜色重建损失(col)是使生成的图像具有原照片的颜色。[7]
如上图公式(2)(3),对于内容丢失和灰度风格丢失,使用预先训练好的VGG作为感知网络,提取图像的高级语义特征;公式(4),关于颜色的提取和转换,将RGB通道转换为YUV通道,然后对不同通道使用不同的损失计算方法;公式(5),是最终生成器的损失函数;公式(6)则是鉴别器使用的损失函数,除了引入CartoonGAN提出的促进边缘的对抗损失,还使用了一种新的灰度对抗损失,防止生成的图像以灰度图像形式显示。
2.3 流程图
2.3.1 算法流程图
2.3.2 爬虫流程图
2.3.3 系统用例图
2.3.4 Django框架图
2.4 各模块属性及申明
2.4.1 爬虫
1.from urllib import request:导入Python内置库urllib的request模块,用于发送http请求和获取响应数据;
2.import os:导入Python内置库os,用于提供与操作系统交互的功能;
3.import re:导入Python内置库re,用于支持正则表达式的操作;
4.def get_path(classname, subclassname, filename):定义一个名为get_path的函数,用于生成图像文件的保存路径,接受类别名称、子类名称和文件名作为参数,返回一个表示图像文件路径的字符串;
5.url:通过拼接字符串的方式构建表示要访问的网址;
6.header:定义一个HTTP请求头,模拟浏览器User-Agent信息,以便发送请求时能够得到正确的响应;
2.4.2 模型
1.设置PyTorch的cuDNN行为:
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled
:禁用cuDNN加速。
torch.backends.cudnn.benchmark
:使用cuDNN的性能优化。
torch.backends.cudnn.deterministic
:确保cuDNN的结果是确定性的。
2.加载图像的函数 load_image:
def load_image(image_path, x32=False):
# ... (函数的实现)
load_image
函数加载图像,进行颜色空间转换,并可选择是否调整图像大小。
参数:
image_path
:图像文件路径。
x32
:布尔值,表示是否将图像大小调整为32的倍数。
3.测试函数 test
:
def test(args):
# ... (函数的实现)
test
函数进行漫画风格转换的测试。
参数:
args
:包含命令行参数的命名空间。
主要步骤:
加载生成器模型。
遍历输入目录中的图像。
对每个图像进行处理并保存结果。
4.命令行参数的解析:
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='path/to/your/model.pth')
parser.add_argument('--input_dir', type=str, default='path/to/your/input/images')
parser.add_argument('--output_dir', type=str, default='path/to/your/output/directory')
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--upsample_align', type=bool, default=False)
parser.add_argument('--x32', action="store_true")
args = parser.parse_args()
argparse.ArgumentParser()
:创建参数解析器。
parser.add_argument()
:定义命令行参数。
args
:解析后的命令行参数。
5.模型加载和推理:
net = Generator()
net.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
net.to(device).eval()
Generator()
:创建生成器模型的实例。
torch.load(args.checkpoint, map_location="cpu")
:加载预训练模型的权重。
net.to(device).eval()
:将模型移动到指定的设备并设置为评估模式。
6.目录的创建:
os.makedirs(args.output_dir, exist_ok=True)
创建输出目录,如果目录已存在则不报错。
7.图像处理和保存:
for image_name in sorted(os.listdir(args.input_dir)):
# ... (图像处理和保存的逻辑)
遍历输入目录中的图像文件。
跳过不是 .jpg
, .png
, .bmp
, .tiff
格式的文件。
调用 load_image
函数加载图像。
使用生成器模型进行推理,得到漫画风格图像。
将生成的图像保存到输出目录。
三、 系统详细设计
3.1 爬虫
我们设计爬虫是为了获取模型测试集,以方便调整模型参数得到更好的训练效果。按照聚焦网络爬虫的简要流程图,我们开发设计了一个可以根据关键词分类进行检索的爬虫。
3.1.1 对爬取目标的定义和描述
依据爬取需求定义好该聚焦网络爬虫爬取的目标及进行相关描述。
classnames = ['super_star', 'cartoon'] # 明星和动漫人物
keypoints = ['%E6%98%8E%E6%98%9F', '%E5%8A%A8%E6%BC%AB%E4%BA%BA%E7%89%A9'] # 关键字对应
gender_file_path = ['male', 'female'] # 对于人的分类检索可以按性别筛选
因为模型需要的是一组真实人物的图片和一组动漫人物的图片,因此我们用classname类定义两个字符串,分别是super_star(明星)和cartoon(动漫人物),假设我们对这两类进行分类检索。
每个字符串都是一个URL编码后的关键字,分别表示“明星”和“动漫人物”,在实际应用中,这些关键字通常是来自于对应网站检索功能的参数用于进行检索。
3.1.2 获取URL
以关键词“明星”为例,搜索后的网址为https://www.vcg.com/creative-image/mingxing/,
按F12查看源码,根据初始的URL爬取页面,并获得新的URL。如下图所示。
检索词为“明星”部分网页源码
检索词为“明星”部分网络参数
其中一个网页图片的完整URL
因为聚焦网络爬虫对网页的爬取是有目的性的,所以与目标无关的网页将会被过滤掉。同时,也需要将已爬取的URL地址存放到一个URL列表中,用于去重和判断爬取的进程。将过滤后的链接放到URL队列中。对聚焦网络爬虫来说,不同的爬取顺序可能导致爬虫的执行效率不同,因此需要依据搜索策略来确定下一步需要爬取哪些URL地址。
for page in range(1, all_page + 1):
num_in_page = 1
# 获得url链接,这里额外增加了筛选条件:图片中仅有一个人
url = 'https://www.vcg.com/creative/search?phrase=' + keypoints[
class_index] + '&creativePeopleNum=2&creativeGender=' + str(gender_index + 1) + '&page=' + str(
page)
header = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/69.0.3497.81 Safari/537.36',
# 'User-Agent':'Mozilla/5.0 (Windows; U; Windows NT 6.1; en-US; rv:1.9.1.6) Gecko/20091201 Firefox/3.5.6'
}
req = request.Request(url=url, headers=header)
openhtml = request.urlopen(req).read().decode('utf8')
# 正则表达式
com = re.compile('"url800":.*?/creative/.*?.jpg"')
# 匹配URl地址
urladds = com.findall(openhtml)
以上代码是截取了一部分URL的代码,通过循环遍历每一页的内容,要求构建向服务器发送请求以获取特定条件下图片资源的URL,随后构建一个HTTP请求对象,请求对象包含了要请求的URL和自定义的请求头部信息,随后发送该请求并读取响应数据,并进行解码,将其转换为utf-8编码的字符串。
for urladd in urladds:
# try ... except防止匹配出错后程序停止
try:
add = 'http:' + urladd.strip('"url800":')
# 获取文件名称,格式:vcg+性别+获取方式+page+page中的第几张图片,vcg_raw代表原vcg网站对性别分类
filename = classnames[class_index] + '_' + gender_file_path[
gender_index] + '_vcg_raw_page' + str(page) + '_' + str(num_in_page) + '.jpg'
path = get_path(classnames[class_index], gender_file_path[gender_index], filename)
print('当前下载...', filename)
dom = request.urlopen(add).read()
with open(path, 'wb') as f:
f.write(dom)
sum_all_num += 1
num_in_page += 1
except:
print('当前该任务总共总共下载:', sum_all_num) # 监控进度
if sum_all_num % 50 == 0: # 监控进度
print('当前该任务总共下载:', sum_all_num)
随后构建一个完整的图片URL,并确定生成的图片文件名,包括类别、性别、来源、页面和图片在当前页中的序号。我们使用urllib.request.urlopen方法下载图片,并将其保存到本地指定的路径中,在下载过程中通过打印信息来监控下载进度,每下载50张输出一次当前下载总量。
com = re.compile('"url800":.*?/creative/.*?.jpg"')
下面详细解释一下正则表达式匹配HTML页面中图片的URL地址。”url800”:表示匹配以此开头的字符串; .*?表示匹配了任意数量的字符,这里的.表示匹配除换行符之外的任意字符,表示匹配前面的字符零次或多次,?表示尽可能少地匹配,这里的作用是匹配任意数量的字符; /creative/表示表达式匹配了/creative/这部分固定的字符串; .*?.jpg再次使用了.*?匹配任意数量的字符,直到.jpg*结尾。
3.1.3 保存
def get_path(classname,subclassname,filename):
#获取当前工作路径
cwd = os.getcwd()
#获取图像的保存目录
dir_path = cwd+'/vcg_test/' + classname +'/'+ subclassname
#目录是否存在,不存在则创建目录
if os.path.exists(dir_path):
pass
else:
os.makedirs(dir_path)
#获取图像的绝对路径
file_path = dir_path +'/'+ filename
return file_path
使用os模块的getcwd函数获取当前工作路径,即运行该代码的当前目录。并
根据类别名称和子类别名称拼接出图像的保存目录。这里假设图像保存的根目录是vcg_test,在根目录下再按照类别和子类别进行保存。通过os.path.exists判断图像保存的目录是否存在,如果不存在则通过os.makedirs创建。随后将图像的保存路径和文件名拼接起来,生成图像的绝对路径。返回图像文件的绝对路径。
3.2 前端
3.2.1 数据模型设计
根据前期设想并制定的目标,我们需要实现用户可以通过浏览器页面上传图片,并进行相应的动漫图像生成。理论上在用户点击上传文件后,利用Django框架后台可以直接调用模型对图像文件进行处理,然后将风格转换后的图片效果反馈给用户,不涉及对数据库的相关操作,但是这样会大大降低系统后续的可开发性,导致一些个针对用户个性化的内容和其他高级功能无法设计实现。为避免上述情况的发生,我们希望可以尽可能完整的实现系统,提前为系统将来的拓展开发做准备,因此将对数据库的使用考虑在设计过程中。
3.2.2 系统功能设计
本人在系统功能设计部分主要负责登录、注册页面设计,故在此处仅详细写了这两部分,其他由队友完成的部分仅贴了相关功能示意图和一些必要的说明。
3.2.2.1 图像文件的上传与收集
3.2.2.2 动漫图像的生成与反馈
3.2.2.3 动漫人像的展示
3.2.2.4 动漫风景的展示
模型根据不同的权重分为风景动漫迁移和人脸动漫迁移两类,目前已有的包括celeba_distill.pt和paprika.pt两个风景动漫迁移模型和face_paint_512_v1.pt和face_paint_512_v2.pt两个人脸动漫迁移模型。有必要说明的是,本系统暂时只支持使用face_paint_512_v2.pt模型进行动漫图像生成,一方面是因为系统初始目标是动漫人脸的生成,另一方面在使用大量图片进行测试比较之后,我们发现face_paint_512_v2.pt对风景图像进行动漫风格迁移的效果尚可。但本系统为了展现模型原始的风格迁移效果,在动漫风景展示页面所展示的图像都准备使用风景动漫迁移模型进行生成。
动漫风景展示页面会提供正式的图像下载接口,人脸图像展示页没有提供下载接口,是考虑到可能设计隐私方面的问题,因此暂时作罢。页面的展示效果需要以大图的方式展现,以丰富系统的功能。
3.2.2.5 注册与登录
在登录功能实现后,就可以对系统的一些功能页面设置访问权限了,这可以通过设置中间件的方法实现,设置的中间件将对用户从前端页面发起所有的请求进行处理,当系统后台无法从数据库的session中获取请求体相应的用户信息时认为用户为进行登录,于是拒绝请求继续访问后台,具体实现过程将在系统实现中说明。
3.2.3 系统架构设计
3.2.3.1 功能模块化
目前系统要实现的功能主要分为两大板块,分别是登录注册功能板块和动漫图像生成展示板块,考虑短时间内动漫图像的生成与展示部分功能较少,所以合并在同一文件内进行设计。登录注册部分预期设计实现的功能模块有登录模块、登出模块、注册模块和密码修改模块。动漫图像生成与展示部分预期预计实现的功能模块有首页模块、动漫图像上传与生成模块、动漫图像展示模块。
3.2.3.2 模板复用
项目开发过程中大部分视图函数会通过html模板文件向浏览器返回显示相关的内容,这些模板文件使用相同的编写规则,这导致很多时候模板设计过程中会重复编写一些模板代码,为了避免出现这种代码冗余的情况,需要充分利用模板可复用的特性。目前主要实现模板复用的方式有三种,分别是宏、继承和包含。宏类似函数,可以传入参数,需要定义、调用;继承本质是代码替换,一般用来实现多个页面中重复不变的区域;包含是直接将目标模板文件整个渲染出来。模板复用能很大程度上提高开发效率,减轻开发负担。
3.2.3.3 模型独立化
系统开发的目的是向用户开放模型的使用和向用户展示模型训练的效果,使模型独立化有利于系统的迁移使用和模型的更换。在系统开发过程中,将尽可能的避免模型的嵌入,主要以导入或调用的方式将模型代入系统。
3.2.4 前端交互界面设计
目前系统前端交互界面主要包括用户进行动漫图像生成页面、登录注册页面和图像展示页面,其中动漫图像生成页面的预设如下图所示。
四、 系统实现
4.1 爬虫
我们运行爬虫代码,可以自动生成如下的文件夹框架,用于分别保存明星和动漫人物的男性、女性照片,代码运行示例图如下图所示:
文件夹框架图
代码运行示例图
动漫图像女性示例图
明星图像男性示例图
4.2 前端
4.2.1 Django项目环境搭建
4.2.1.1 虚拟环境
系统应用程序的开发在Conda虚拟环境中进行,基于模型以及现有数据库MySQL的版本进行相关软件包的组合搭配,具体见上文系统软件包的说明。使用Pycharm在该虚拟环境中完成Django项目的初步创建和配置。
4.2.1.2 数据库
数据库使用5.7.41版本的MySQL,版本较老,需要降低Django的版本。数据库的创建可以通过管理员CMD命令实现,本人电脑中配备了MySQL可视化管理工具Navicat,所以最终选择利用该工具完成数据库的创建,数据库命名为animegandb,具体如下图所示。另外在Django项目文件settings.py中,默认连接的是sqlite3,所以要手动配置连接建立的MySQL。
4.2.1.3 项目准备
项目根目录中需要另外新建media目录用于存放用户上传的媒体文件,由于Django中存在相关的保护机制,用户无法直接访问获取media目录中需要的内容,有必要要在settings.py和urls.py中完成路径参数配置。
4.2.2 登录与注册
4.2.2.1 实现逻辑
登录和注册功能的实现分别对应login和signup视图函数,两个视图函数分别向对应的模板传递参数,用户点击页面上的交互按钮后前端页面将数据表单送到对应的视图函数处理。
登录视图函数login分别对GET和POST两种请求方式做出不同的处理,当请求方式为GET时,将创建的LoginForm对象传给模板文件进行前端渲染,效果为用户通过连接访问登录页时页面显示出对应的样式和登录框组。当请求方式为POST时,即用户点击了登录按钮后页面模板将包裹在