程序简介

项目调用tensorflow.keras搭建超分辨率生成对抗网络来提高图片分辨率,训练用的数据集则是500张图片
程序输入:60x60的图片
程序输出:120x120的图片

超分辨率生成对抗网络(SRGAN):从其低分辨率(LR)对应物估计高分辨率(HR)图像的极具挑战性的任务被称为超分辨率(SR)。SRGAN是一种用于图像超分辨率(SR)的生成对抗网络(GAN)。

程序/数据集下载

点击进入下载地址

图片迭代器 Module/Collect.py

导入模块、路径

# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import cv2
import os

#路径目录
baseDir = ''#根目录
staticDir = os.path.join(baseDir,'Static')#静态文件目录
resultDir = os.path.join(baseDir,'Result')#结果文件目录
imgsDir = staticDir+'/图片'#图片目录
names = os.listdir(imgsDir)#图片名集合

图片增强函数,随机让图片旋转0-360度,测试查看图片旋转效果

def augment(name):
    '''
    读取图片,并做随机旋转操作,返回图片矩阵
    name:图片名
    '''
    imgPath = imgsDir + '/' + name#图片路径
    #图片矩阵
    img = cv2.imdecode(np.fromfile(imgsDir+'/'+name, dtype=np.uint8),-1)
    #旋转图片矩阵
    img = np.rot90(img,k=np.random.randint(4))
    return img
#同一张图片名输入两次,得到两张不同图片
img1 = augment(names[1])
img2 = augment(names[1])
combine = np.concatenate((img1,img2), axis=1)
plt.matshow(combine)

图片缩小函数,原图为网络期望输出,缩小图片为网络输入,测试查看效果

def reduce(img):
    '''
    缩小图片的长宽为原来的一半,返回小图
    img:图片矩阵
    '''
    #采用双线性插值算法缩小图片
    miniImg = cv2.resize(img,(int(img.shape[0]/2),int(img.shape[1]/2)), interpolation=cv2.INTER_LINEAR)
    return miniImg
img3 = reduce(img1)
plt.matshow(img1)
plt.title('原图')
plt.matshow(img3)
plt.title('缩小图')

因为图片像素区间为[0,255],而神经网络输入输出的区间最好是[-1,1],所以需要下文的函数对像素值进行归一化和还原操作,测试查看效果

def normalizeImg(img):
    '''将图片归一化到-1,1,这里可以有小数'''
    img = (img/255 - 0.5)*2
    return img

def reverseImg(img):
    '''将图片还原到原数量级,这里不能有小数'''
    img = (img + 1)*255/2
    img = img.astype(np.uint8)
    return img
print('原数',[0,255],'归一化后',normalizeImg(np.array([0,255])),'还原后',reverseImg(np.array([-1,1])))
原数 [0, 255] 归一化后 [-1.  1.] 还原后 [  0 255]

图片迭代器,调用上文定义的函数,每次调用都会随机抽取批处理量的图片,并且对图片进行随机增强的操作函数,返回的数据为DataFrame,分为4列,未归一化的输入输出集,归一化的输入输出集

def collect(batchSize):
    '''
    随机批量抽取图片作为训练输入输出
    batchSize:批量大小
    '''
    #随机选择batch张图片
    choosNames = np.random.choice(names,batchSize,replace=False)
    data = pd.DataFrame({'name':choosNames})
    
    #原输出集
    data['output'] = data['name'].apply(augment)
    #原输入集
    data['input'] = data['output'].apply(reduce)
    #归一化输入集
    data['normalInput'] = data['input'].apply(normalizeImg)
    #归一化输出集
    data['normalOutput'] = data['output'].apply(normalizeImg)
    return data

搭建SRGAN框架 Module/BuileNet.py

导入模块

# -*- coding: utf-8 -*-
from tensorflow.keras.layers import Input,Dense,Conv2D,Flatten,BatchNormalization,UpSampling2D
from tensorflow.keras.layers import PReLU,Add,Concatenate,LeakyReLU
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam,RMSprop
from tensorflow.keras.losses import mean_squared_error as mse
from tensorflow.keras.losses import mean_absolute_error as mae
from tensorflow.keras.applications import VGG19
from tensorflow.keras.applications.vgg19 import preprocess_input
import tensorflow.keras.backend as K
import tensorflow as tf
import numpy as np

生成器构建函数,即将60x60的图片超分为120x120的图片的神经网络,其中比较重要的结构被称为残差块,即程序中的resBlock函数,这里没配置损失函数和优化器,是因为生成器的训练在下文的对抗网络训练过程中

def resBlock(xIn,filterNum):
    '''残差块'''
    x = Conv2D(filters=filterNum,kernel_size=3,padding='same')(xIn)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    x = Conv2D(filters=filterNum,kernel_size=3,padding='same')(x)
    x = BatchNormalization()(x)
    x = Add()([xIn, x])
    x = LeakyReLU()(x)
    return x

def createGenerator(layerNum,filterNum):
    '''
    创建生成器
    layerNum:残差块数
    filterNum:残差块卷积核数
    '''
    #输入层
    inputLayer = Input(shape=(None,None,3))
    
    #第一层
    firstLayer = Conv2D(filters=filterNum,kernel_size=3,padding='same')(inputLayer)
    firstLayer = BatchNormalization()(firstLayer)
    firstLayer = LeakyReLU()(firstLayer)

    #中间层    
    middle = firstLayer
    for num in range(layerNum):
        middle = resBlock(middle,filterNum)   
    middle = Conv2D(filters=filterNum,kernel_size=3,padding='same')(middle)
    middle = BatchNormalization()(middle)
    middle = LeakyReLU()(middle)
    middle = Add()([firstLayer,middle])
    middle = UpSampling2D(size=2)(middle)
    
    #输出层
    outputLayer = Conv2D(filters=3,kernel_size=9,padding="same",activation='tanh')(middle)
    
    #建模
    model = Model(inputs=inputLayer,outputs=outputLayer)
    return model

判别器构建函数,即判断生成器生成的高清图片是否为真实图片,判别器差不多就是普通的分类卷积神经网络,输出在[0,1]区间,损失函数则是二分类损失

def block(xIn,filterNum):
    '''卷积+标准化+激活块'''
    x = Conv2D(filterNum,kernel_size=3,strides=3,padding='same')(xIn)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)
    return x

def createDiscriminator(layerNum,filterNum,lr):
    '''
    创建判别器
    layerNum:中间块数
    filterNum:中间块卷积核数
    lr:学习率
    '''    
    #输入层
    inputLayer = Input(shape=(120,120,3))
    
    #中间层
    middle = inputLayer
    for num in range(layerNum):
        middle = block(middle,filterNum)  
    middle = Flatten()(middle)
    middle = Dense(1000)(middle)
    middle = LeakyReLU()(middle)
    
    #输出层
    outputLayer = Dense(1, activation='sigmoid')(middle)
    
    #建模
    model = Model(inputs=inputLayer,outputs=outputLayer)
    #优化器
    optimizer = RMSprop(lr=lr)
    model.compile(optimizer=optimizer, loss='binary_crossentropy')
    return model

构建对抗网络,即组合生成器和判别器,形成新的网络,构成这个网络的原因是为了训练生成器,生成器的目的就是迷惑判别器,组合后的网络先将判别器部分的参数固定,然后训练生成器部分的参数,使得判别器分不清真实和生成图片

注意,对抗网络的输入和输出都在[-1,1]的区间内,而在计算内容损失时需要将图片还原,所以这里定义一个reverseImg图片还原函数,内容损失是对抗网络的损失函数

def reverseImg(img):
    '''将图片还原到原数量级'''
    img = (img + 1)*255/2
    return img
print('处理前',np.array([-1,1]),'处理后',reverseImg(np.array([-1,1])))
处理前 [-1  1] 处理后 [  0. 255.]

对抗网络的输出有两部分,第一部分为组合判别网判断生成图片是否为真实图片,第二部分为组合生成器生成的图片

与之对应的损失函数也有两部分,第一部分为组合后判别器的二分类损失,第二部分为内容损失,内容损失这里不是将原高清图片与生成图片进行MSE计算,而是需要用VGG19网络进行特征提取,然后对两张图片的特征进行MSE计算

#特征提取器
vgg19 = VGG19(include_top=False, weights='imagenet')
vgg19 = Model(vgg19.input, vgg19.output)

def contentLoss(y_true, y_pred):
    '''内容损失'''
    y_true = reverseImg(y_true)
    y_pred = reverseImg(y_pred)
    y_true = preprocess_input(y_true)
    y_pred = preprocess_input(y_pred)
    sr = vgg19(y_pred)
    hr = vgg19(y_true)
    return mse(y_true, y_pred)

def createGan(generator,discriminator,lr):
    '''构建对抗网'''
    discriminator.trainable = False
    #生成器输入
    lowImg = generator.input
    #生成器输出
    fakeHighImg = generator(lowImg)
    #生成器判断
    judge = discriminator(fakeHighImg)
    model = Model(inputs=lowImg,outputs=[judge,fakeHighImg])
    optimizer = RMSprop(lr=lr)
    model.compile(optimizer=optimizer, loss=['binary_crossentropy', contentLoss],loss_weights=[1, 1e-1])
    model.summary()
    return model

实例化生成器、判别器、对抗网络

generator = createGenerator(5,100)#生成器
discriminator = createDiscriminator(10,100,1e-5)#判别器
print('打印内容为对抗网络结构')
gan = createGan(generator,discriminator,1e-5)#对抗网络
打印内容为对抗网络结构
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_9 (InputLayer)         (None, None, None, 3)     0         
_________________________________________________________________
model_10 (Model)             (None, None, None, 3)     1023003   
_________________________________________________________________
model_11 (Model)             (None, 1)                 919701    
=================================================================
Total params: 1,942,704
Trainable params: 1,020,603
Non-trainable params: 922,101
_________________________________________________________________

训练网络,查看效果 Main.py

导入模块、路径、预设参数

# -*- coding: utf-8 -*-
from Module.BuidModel import createGenerator,createDiscriminator,createGan
from Module.Collect import collect,reverseImg
import cv2
import numpy as np
import os

############################可调整参数##########################
batchSize = 10#批处理量
genFilters = 100#生成器核数
disFilters = 100#判别器核数
genLayers = 5#生成器残差块层数
disLayers = 10#判别器卷积块数
genLearnRate = 5e-5#生成学习率
disLearnRate = 1e-4#判别学习率
##############################################################

#路径目录
baseDir = ''#当前目录
staticDir = os.path.join(baseDir,'Static')#静态文件目录
resultDir = os.path.join(baseDir,'Result')#结果文件目录

实例化生成器、判别器、对抗网络,效果与上文的实例化演示一致

generator = createGenerator(genLayers,genFilters)#生成器
discriminator = createDiscriminator(disLayers,disFilters,disLearnRate)#判别器
gan = createGan(generator,discriminator,genLearnRate)#对抗网络

进入训练的无限循环,每个epoch随机抽取图片,首先用生成器从低清图生成到伪高清图,然后将伪高清图和原高清图作为输入到判别器

判别器的训练目的在于给伪高清图打标签0,给原高清图片打标签1

最后组合生成器和判别器,固定判别器的参数,训练对抗网络(即生成器的参数),目的是让生成器混淆判别器的识别能力,使生成器的生成图片尽可能的被判别器打为标签1

每100个epoch保存效果图,下文查看效果

epochs = 0#迭代次数
loss = {'dLoss':[],'gLoss':[],'cLoss':[]}
while True:
    epochs += 1
    imgs = collect(batchSize)#随机抽取图片
    #归一化的输入
    try:
        low = np.array(imgs['normalInput'].values.tolist()).reshape(batchSize,60,60,3)
    except:
        epochs -= 1
        print('error')
        continue
    #生成高清图(-1,1)
    fakeHigh = generator.predict(low)
    #原高清图(-1,1)
    realHigh = np.array(imgs['normalOutput'].values.tolist()).reshape(-1,120,120,3)
    #真伪标签
    realBool = np.random.uniform(0.7,1,size=(batchSize,))
    fakeBool = np.random.uniform(0,0.3,size=(batchSize,))

    #鉴别器训练
    discriminator.trainable = True
    dRealLoss = discriminator.train_on_batch(x=realHigh, y=realBool)
    dFakeLoss = discriminator.train_on_batch(x=fakeHigh, y=fakeBool)
    #判别损失
    loss['dLoss'].append(0.5 * (dRealLoss + dFakeLoss))
    
    #生成器训练
    discriminator.trainable = False
    ganLoss = gan.train_on_batch(x=low, y=[realBool,realHigh])
    #对抗损失
    loss['gLoss'].append(ganLoss[1])
    #内容损失
    loss['cLoss'].append(ganLoss[2])   
    
    if epochs%100==0:
        #打印损失
        dLoss = np.array(loss['dLoss'][-100:]).mean()
        gLoss = np.array(loss['gLoss'][-100:]).mean()
        cLoss = np.array(loss['cLoss'][-100:]).mean()
        print('epoch:%d dLoss:%.4f gLoss:%.4f cLoss:%.4f'%(epochs,dLoss,gLoss,cLoss))
        #保存模型
        generator.save_weights(resultDir+'/generator.h5')
        discriminator.save_weights(resultDir+'/discriminator.h5')
        #原低清图(-1,1)
        lowImg = low[0]
        #生成高清图
        fakeHigh = generator.predict(lowImg[np.newaxis,:]).reshape((120,120,3))
        fakeHigh = reverseImg(fakeHigh)       
        #传统双线性插值放大
        lineHigh = cv2.resize(reverseImg(lowImg),(120,120), interpolation=cv2.INTER_LINEAR)
        #原图像
        originHigh = reverseImg(realHigh[0])
        #组合图片
        combine = np.concatenate((lineHigh,fakeHigh,originHigh), axis=1)
        cv2.imencode('.jpg',combine)[1].tofile(resultDir+'/compare.jpg')

从左往右数,图1为传统的双线性插值法的超分结果,图2为超分对抗网络的超分结果,图3为原高清图

posted on 2020-03-16 15:51  爆米LiuChen  阅读(3414)  评论(6编辑  收藏  举报