程序简介
项目调用tensorflow.keras搭建超分辨率生成对抗网络来提高图片分辨率,训练用的数据集则是500张图片
程序输入:60x60的图片
程序输出:120x120的图片
超分辨率生成对抗网络(SRGAN):从其低分辨率(LR)对应物估计高分辨率(HR)图像的极具挑战性的任务被称为超分辨率(SR)。SRGAN是一种用于图像超分辨率(SR)的生成对抗网络(GAN)。
程序/数据集下载
图片迭代器 Module/Collect.py
导入模块、路径
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import cv2
import os
#路径目录
baseDir = ''#根目录
staticDir = os.path.join(baseDir,'Static')#静态文件目录
resultDir = os.path.join(baseDir,'Result')#结果文件目录
imgsDir = staticDir+'/图片'#图片目录
names = os.listdir(imgsDir)#图片名集合
图片增强函数,随机让图片旋转0-360度,测试查看图片旋转效果
def augment(name):
'''
读取图片,并做随机旋转操作,返回图片矩阵
name:图片名
'''
imgPath = imgsDir + '/' + name#图片路径
#图片矩阵
img = cv2.imdecode(np.fromfile(imgsDir+'/'+name, dtype=np.uint8),-1)
#旋转图片矩阵
img = np.rot90(img,k=np.random.randint(4))
return img
#同一张图片名输入两次,得到两张不同图片
img1 = augment(names[1])
img2 = augment(names[1])
combine = np.concatenate((img1,img2), axis=1)
plt.matshow(combine)
图片缩小函数,原图为网络期望输出,缩小图片为网络输入,测试查看效果
def reduce(img):
'''
缩小图片的长宽为原来的一半,返回小图
img:图片矩阵
'''
#采用双线性插值算法缩小图片
miniImg = cv2.resize(img,(int(img.shape[0]/2),int(img.shape[1]/2)), interpolation=cv2.INTER_LINEAR)
return miniImg
img3 = reduce(img1)
plt.matshow(img1)
plt.title('原图')
plt.matshow(img3)
plt.title('缩小图')
因为图片像素区间为[0,255],而神经网络输入输出的区间最好是[-1,1],所以需要下文的函数对像素值进行归一化和还原操作,测试查看效果
def normalizeImg(img):
'''将图片归一化到-1,1,这里可以有小数'''
img = (img/255 - 0.5)*2
return img
def reverseImg(img):
'''将图片还原到原数量级,这里不能有小数'''
img = (img + 1)*255/2
img = img.astype(np.uint8)
return img
print('原数',[0,255],'归一化后',normalizeImg(np.array([0,255])),'还原后',reverseImg(np.array([-1,1])))
原数 [0, 255] 归一化后 [-1. 1.] 还原后 [ 0 255]
图片迭代器,调用上文定义的函数,每次调用都会随机抽取批处理量的图片,并且对图片进行随机增强的操作函数,返回的数据为DataFrame,分为4列,未归一化的输入输出集,归一化的输入输出集
def collect(batchSize):
'''
随机批量抽取图片作为训练输入输出
batchSize:批量大小
'''
#随机选择batch张图片
choosNames = np.random.choice(names,batchSize,replace=False)
data = pd.DataFrame({'name':choosNames})
#原输出集
data['output'] = data['name'].apply(augment)
#原输入集
data['input'] = data['output'].apply(reduce)
#归一化输入集
data['normalInput'] = data['input'].apply(normalizeImg)
#归一化输出集
data['normalOutput'] = data['output'].apply(normalizeImg)
return data
搭建SRGAN框架 Module/BuileNet.py
导入模块
# -*- coding: utf-8 -*-
from tensorflow.keras.layers import Input,Dense,Conv2D,Flatten,BatchNormalization,UpSampling2D
from tensorflow.keras.layers import PReLU,Add,Concatenate,LeakyReLU
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam,RMSprop
from tensorflow.keras.losses import mean_squared_error as mse
from tensorflow.keras.losses import mean_absolute_error as mae
from tensorflow.keras.applications import VGG19
from tensorflow.keras.applications.vgg19 import preprocess_input
import tensorflow.keras.backend as K
import tensorflow as tf
import numpy as np
生成器构建函数,即将60x60的图片超分为120x120的图片的神经网络,其中比较重要的结构被称为残差块,即程序中的resBlock函数,这里没配置损失函数和优化器,是因为生成器的训练在下文的对抗网络训练过程中
def resBlock(xIn,filterNum):
'''残差块'''
x = Conv2D(filters=filterNum,kernel_size=3,padding='same')(xIn)
x = BatchNormalization()(x)
x = LeakyReLU()(x)
x = Conv2D(filters=filterNum,kernel_size=3,padding='same')(x)
x = BatchNormalization()(x)
x = Add()([xIn, x])
x = LeakyReLU()(x)
return x
def createGenerator(layerNum,filterNum):
'''
创建生成器
layerNum:残差块数
filterNum:残差块卷积核数
'''
#输入层
inputLayer = Input(shape=(None,None,3))
#第一层
firstLayer = Conv2D(filters=filterNum,kernel_size=3,padding='same')(inputLayer)
firstLayer = BatchNormalization()(firstLayer)
firstLayer = LeakyReLU()(firstLayer)
#中间层
middle = firstLayer
for num in range(layerNum):
middle = resBlock(middle,filterNum)
middle = Conv2D(filters=filterNum,kernel_size=3,padding='same')(middle)
middle = BatchNormalization()(middle)
middle = LeakyReLU()(middle)
middle = Add()([firstLayer,middle])
middle = UpSampling2D(size=2)(middle)
#输出层
outputLayer = Conv2D(filters=3,kernel_size=9,padding="same",activation='tanh')(middle)
#建模
model = Model(inputs=inputLayer,outputs=outputLayer)
return model
判别器构建函数,即判断生成器生成的高清图片是否为真实图片,判别器差不多就是普通的分类卷积神经网络,输出在[0,1]区间,损失函数则是二分类损失
def block(xIn,filterNum):
'''卷积+标准化+激活块'''
x = Conv2D(filterNum,kernel_size=3,strides=3,padding='same')(xIn)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
return x
def createDiscriminator(layerNum,filterNum,lr):
'''
创建判别器
layerNum:中间块数
filterNum:中间块卷积核数
lr:学习率
'''
#输入层
inputLayer = Input(shape=(120,120,3))
#中间层
middle = inputLayer
for num in range(layerNum):
middle = block(middle,filterNum)
middle = Flatten()(middle)
middle = Dense(1000)(middle)
middle = LeakyReLU()(middle)
#输出层
outputLayer = Dense(1, activation='sigmoid')(middle)
#建模
model = Model(inputs=inputLayer,outputs=outputLayer)
#优化器
optimizer = RMSprop(lr=lr)
model.compile(optimizer=optimizer, loss='binary_crossentropy')
return model
构建对抗网络,即组合生成器和判别器,形成新的网络,构成这个网络的原因是为了训练生成器,生成器的目的就是迷惑判别器,组合后的网络先将判别器部分的参数固定,然后训练生成器部分的参数,使得判别器分不清真实和生成图片
注意,对抗网络的输入和输出都在[-1,1]的区间内,而在计算内容损失时需要将图片还原,所以这里定义一个reverseImg图片还原函数,内容损失是对抗网络的损失函数
def reverseImg(img):
'''将图片还原到原数量级'''
img = (img + 1)*255/2
return img
print('处理前',np.array([-1,1]),'处理后',reverseImg(np.array([-1,1])))
处理前 [-1 1] 处理后 [ 0. 255.]
对抗网络的输出有两部分,第一部分为组合判别网判断生成图片是否为真实图片,第二部分为组合生成器生成的图片
与之对应的损失函数也有两部分,第一部分为组合后判别器的二分类损失,第二部分为内容损失,内容损失这里不是将原高清图片与生成图片进行MSE计算,而是需要用VGG19网络进行特征提取,然后对两张图片的特征进行MSE计算
#特征提取器
vgg19 = VGG19(include_top=False, weights='imagenet')
vgg19 = Model(vgg19.input, vgg19.output)
def contentLoss(y_true, y_pred):
'''内容损失'''
y_true = reverseImg(y_true)
y_pred = reverseImg(y_pred)
y_true = preprocess_input(y_true)
y_pred = preprocess_input(y_pred)
sr = vgg19(y_pred)
hr = vgg19(y_true)
return mse(y_true, y_pred)
def createGan(generator,discriminator,lr):
'''构建对抗网'''
discriminator.trainable = False
#生成器输入
lowImg = generator.input
#生成器输出
fakeHighImg = generator(lowImg)
#生成器判断
judge = discriminator(fakeHighImg)
model = Model(inputs=lowImg,outputs=[judge,fakeHighImg])
optimizer = RMSprop(lr=lr)
model.compile(optimizer=optimizer, loss=['binary_crossentropy', contentLoss],loss_weights=[1, 1e-1])
model.summary()
return model
实例化生成器、判别器、对抗网络
generator = createGenerator(5,100)#生成器
discriminator = createDiscriminator(10,100,1e-5)#判别器
print('打印内容为对抗网络结构')
gan = createGan(generator,discriminator,1e-5)#对抗网络
打印内容为对抗网络结构
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_9 (InputLayer) (None, None, None, 3) 0
_________________________________________________________________
model_10 (Model) (None, None, None, 3) 1023003
_________________________________________________________________
model_11 (Model) (None, 1) 919701
=================================================================
Total params: 1,942,704
Trainable params: 1,020,603
Non-trainable params: 922,101
_________________________________________________________________
训练网络,查看效果 Main.py
导入模块、路径、预设参数
# -*- coding: utf-8 -*-
from Module.BuidModel import createGenerator,createDiscriminator,createGan
from Module.Collect import collect,reverseImg
import cv2
import numpy as np
import os
############################可调整参数##########################
batchSize = 10#批处理量
genFilters = 100#生成器核数
disFilters = 100#判别器核数
genLayers = 5#生成器残差块层数
disLayers = 10#判别器卷积块数
genLearnRate = 5e-5#生成学习率
disLearnRate = 1e-4#判别学习率
##############################################################
#路径目录
baseDir = ''#当前目录
staticDir = os.path.join(baseDir,'Static')#静态文件目录
resultDir = os.path.join(baseDir,'Result')#结果文件目录
实例化生成器、判别器、对抗网络,效果与上文的实例化演示一致
generator = createGenerator(genLayers,genFilters)#生成器
discriminator = createDiscriminator(disLayers,disFilters,disLearnRate)#判别器
gan = createGan(generator,discriminator,genLearnRate)#对抗网络
进入训练的无限循环,每个epoch随机抽取图片,首先用生成器从低清图生成到伪高清图,然后将伪高清图和原高清图作为输入到判别器
判别器的训练目的在于给伪高清图打标签0,给原高清图片打标签1
最后组合生成器和判别器,固定判别器的参数,训练对抗网络(即生成器的参数),目的是让生成器混淆判别器的识别能力,使生成器的生成图片尽可能的被判别器打为标签1
每100个epoch保存效果图,下文查看效果
epochs = 0#迭代次数
loss = {'dLoss':[],'gLoss':[],'cLoss':[]}
while True:
epochs += 1
imgs = collect(batchSize)#随机抽取图片
#归一化的输入
try:
low = np.array(imgs['normalInput'].values.tolist()).reshape(batchSize,60,60,3)
except:
epochs -= 1
print('error')
continue
#生成高清图(-1,1)
fakeHigh = generator.predict(low)
#原高清图(-1,1)
realHigh = np.array(imgs['normalOutput'].values.tolist()).reshape(-1,120,120,3)
#真伪标签
realBool = np.random.uniform(0.7,1,size=(batchSize,))
fakeBool = np.random.uniform(0,0.3,size=(batchSize,))
#鉴别器训练
discriminator.trainable = True
dRealLoss = discriminator.train_on_batch(x=realHigh, y=realBool)
dFakeLoss = discriminator.train_on_batch(x=fakeHigh, y=fakeBool)
#判别损失
loss['dLoss'].append(0.5 * (dRealLoss + dFakeLoss))
#生成器训练
discriminator.trainable = False
ganLoss = gan.train_on_batch(x=low, y=[realBool,realHigh])
#对抗损失
loss['gLoss'].append(ganLoss[1])
#内容损失
loss['cLoss'].append(ganLoss[2])
if epochs%100==0:
#打印损失
dLoss = np.array(loss['dLoss'][-100:]).mean()
gLoss = np.array(loss['gLoss'][-100:]).mean()
cLoss = np.array(loss['cLoss'][-100:]).mean()
print('epoch:%d dLoss:%.4f gLoss:%.4f cLoss:%.4f'%(epochs,dLoss,gLoss,cLoss))
#保存模型
generator.save_weights(resultDir+'/generator.h5')
discriminator.save_weights(resultDir+'/discriminator.h5')
#原低清图(-1,1)
lowImg = low[0]
#生成高清图
fakeHigh = generator.predict(lowImg[np.newaxis,:]).reshape((120,120,3))
fakeHigh = reverseImg(fakeHigh)
#传统双线性插值放大
lineHigh = cv2.resize(reverseImg(lowImg),(120,120), interpolation=cv2.INTER_LINEAR)
#原图像
originHigh = reverseImg(realHigh[0])
#组合图片
combine = np.concatenate((lineHigh,fakeHigh,originHigh), axis=1)
cv2.imencode('.jpg',combine)[1].tofile(resultDir+'/compare.jpg')
从左往右数,图1为传统的双线性插值法的超分结果,图2为超分对抗网络的超分结果,图3为原高清图