用 PyTorch 实现 Style Transfer



# -*- coding: utf-8 -*-
from __future__ import division
from torch.backends import cudnn
from torch.autograd import Variable
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np

# Run on the GPU whenever one is available; `dtype` is the tensor type that
# load_image casts every loaded image to.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

# Load image file and convert it into variable
# unsqueeze for make the 4D tensor to perform conv arithmetic
def load_image(image_path, transform=None, max_size=None, shape=None):
    image =
    if max_size is not None:
        scale = max_size / max(image.size)
        size = np.array(image.size) * scale
        image = image.resize(size.astype(int), Image.ANTIALIAS)
    if shape is not None:
        image = image.resize(shape, Image.LANCZOS)
    if transform is not None:
        image = transform(image).unsqueeze(0)
    return image.type(dtype)

# Pretrained VGGNet 
class VGGNet(nn.Module):
    def __init__(self):
        """Select conv1_1 ~ conv5_1 activation maps."""
        super(VGGNet, self).__init__() = ['0', '5', '10', '19', '28'] 
        self.vgg = models.vgg19(pretrained=True).features
    def forward(self, x):
        """Extract 5 conv activation maps from an input image.
            x: 4D tensor of shape (1, 3, height, width).
            features: a list containing 5 conv activation maps.
        features = []
        for name, layer in self.vgg._modules.items():
            x = layer(x)  # 依次输出并且截取
            if name in
        return features

def main(config):
    """Run Gram-matrix neural style transfer and save intermediate images.

    The optimized variable is the image ``target`` itself (initialized from
    the content image), not the network parameters. The features of
    ``target`` should stay close to those of the content image, and the Gram
    matrices of its features should match those of the style image.

    Args:
        config: argparse namespace providing ``content``, ``style``,
            ``max_size``, ``total_step``, ``log_step``, ``sample_step``,
            ``style_weight`` and ``lr``.

    Side effects:
        Prints the losses every ``log_step`` steps, writes
        ``output2-<step>.png`` every ``sample_step`` steps, and prints the
        VGG weights at steps 0 and 21.
    """
    # Image preprocessing with the ImageNet statistics used by the pretrained
    # VGG (see the torchvision pretrained-models documentation).
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    # ToTensor converts the PIL image to a float tensor before Normalize.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])
    # Exact inverse of the normalization above, used when saving images.
    denorm = transforms.Normalize([-m / s for m, s in zip(mean, std)],
                                  [1 / s for s in std])

    # Load content and style images.
    # PIL's resize takes (width, height) = (size(3), size(2)) of an NCHW
    # tensor, so this makes style.size() == content.size().
    content = load_image(config.content, transform, max_size=config.max_size)
    style = load_image(config.style, transform,
                       shape=[content.size(3), content.size(2)])

    # Optimize the pixels of `target` (not the network parameters), starting
    # from a copy of the content image.
    target = content.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])

    vgg = VGGNet()
    # VGG is only a fixed feature extractor: its weights need no gradients.
    for param in vgg.parameters():
        param.requires_grad_(False)
    if use_cuda:
        vgg.cuda()

    # Content features and style Gram matrices do not depend on `target`, so
    # compute them once instead of on every step.
    with torch.no_grad():
        content_features = vgg(content)
        style_grams = []
        for f3 in vgg(style):
            _, c, h, w = f3.size()
            f3 = f3.view(c, h * w)
            style_grams.append(torch.mm(f3, f3.t()))

    for step in range(config.total_step):
        # The same target image goes through the network on every step; the
        # gradient of the loss flows back to it.
        target_features = vgg(target)

        style_loss = 0
        content_loss = 0
        for f1, f2, g3 in zip(target_features, content_features, style_grams):
            # Content loss: mean squared error between the raw feature maps.
            content_loss += torch.mean((f1 - f2) ** 2)

            # Gram matrix of the target: flatten each of the c channels into
            # a vector of length h*w, then take all channel-pair inner
            # products (a c x c matrix).
            _, c, h, w = f1.size()
            f1 = f1.view(c, h * w)
            g1 = torch.mm(f1, f1.t())

            # Style loss: distance between the two Gram matrices; dividing by
            # c*h*w keeps the terms of different layers on a similar scale.
            style_loss += torch.mean((g1 - g3) ** 2) / (c * h * w)

        # Compute total loss, backprop into the target image and update it.
        loss = content_loss + config.style_weight * style_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % config.log_step == 0:
            print('Step [%d/%d], Content Loss: %.4f, Style Loss: %.4f'
                  % (step + 1, config.total_step,
                     content_loss.item(), style_loss.item()))

        if (step + 1) % config.sample_step == 0:
            # Save the generated image: undo normalization, clip to [0, 1].
            img = target.detach().cpu().squeeze()
            img = denorm(img).clamp_(0, 1)
            torchvision.utils.save_image(img, 'output2-%d.png' % (step + 1))

        # Debug: dump the VGG weights at step 0 and step 21 to check that the
        # optimizer does not change them (only `target` is optimized).
        if step in (0, 21):
            for name, weight in vgg.state_dict().items():
                print(name, weight)
if __name__ == "__main__":
    # Command-line options for the style-transfer run.
    parser = argparse.ArgumentParser()
    parser.add_argument('--content', type=str, default='./png/content.png')
    parser.add_argument('--style', type=str, default='./png/style.png')
    parser.add_argument('--max_size', type=int, default=400)
    parser.add_argument('--total_step', type=int, default=5000)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=1000)
    parser.add_argument('--style_weight', type=float, default=100)
    parser.add_argument('--lr', type=float, default=0.003)
    config = parser.parse_args()
    # Actually start the run with the parsed options.
    main(config)
代码已经写得很清晰干净。大致思路是：作者用 Gram 矩阵来衡量一幅作品的 style，用未经变换的 feature map 来衡量与 content 的相似度。




3. 这种思路类似于 GAN：定义两个 loss，让输出在这两个 loss 之间做权衡。不同的是，GAN 是真/假的二元对立，而 style transfer 是 style 与 content 之间的互补关系（并且这里被优化的是输入图像 target 本身，而不是网络参数）。




What that means is that in each layer we take pairs of feature maps, multiply them pointwise and sum over the spatial extent, and that gives us one correlation value: a number that tells us how correlated these two features are in response to that particular image. If we do this for all pairs of feature maps in the layer, we get a whole correlation matrix, and we can do that for multiple layers of the network. That gives us a texture model with which we model the style. We can now visualize what information is captured by these feature map correlations by generating a new image that does not reproduce the exact feature responses in the network, but rather reproduces the correlations between feature responses. (In formula form: for layer $l$ with vectorized feature maps $F^l \in \mathbb{R}^{C \times HW}$, the Gram matrix is $G^l_{ij} = \sum_k F^l_{ik} F^l_{jk}$, which is `torch.mm(f1, f1.t())` in the code above.)










Step [4800/5000], Content Loss: 80.6877, Style Loss: 487.1260
Step [4810/5000], Content Loss: 80.7026, Style Loss: 486.3426
Step [4820/5000], Content Loss: 80.7246, Style Loss: 485.5407
Step [4830/5000], Content Loss: 80.7389, Style Loss: 484.7999
Step [4840/5000], Content Loss: 80.7581, Style Loss: 484.0170
Step [4850/5000], Content Loss: 80.7689, Style Loss: 483.3676
Step [4860/5000], Content Loss: 80.7863, Style Loss: 482.5543
Step [4870/5000], Content Loss: 80.8044, Style Loss: 481.7750
Step [4880/5000], Content Loss: 80.8238, Style Loss: 481.0148
Step [4890/5000], Content Loss: 80.8351, Style Loss: 480.3542
Step [4900/5000], Content Loss: 80.8555, Style Loss: 479.5468
Step [4910/5000], Content Loss: 80.8675, Style Loss: 478.8164
Step [4920/5000], Content Loss: 80.8801, Style Loss: 478.0892
Step [4930/5000], Content Loss: 80.8959, Style Loss: 477.3603
Step [4940/5000], Content Loss: 80.9091, Style Loss: 476.5545
Step [4950/5000], Content Loss: 80.9220, Style Loss: 475.8526
Step [4960/5000], Content Loss: 80.9387, Style Loss: 475.0789
Step [4970/5000], Content Loss: 80.9549, Style Loss: 474.3080
Step [4980/5000], Content Loss: 80.9718, Style Loss: 473.5461
Step [4990/5000], Content Loss: 80.9910, Style Loss: 472.7648
Step [5000/5000], Content Loss: 81.0086, Style Loss: 471.9411
Step [4800/5000], Content Loss: 90.9087, Style Loss: 259023120.0000
Step [4810/5000], Content Loss: 90.9227, Style Loss: 258663088.0000
Step [4820/5000], Content Loss: 90.9379, Style Loss: 258286064.0000
Step [4830/5000], Content Loss: 90.9492, Style Loss: 257919232.0000
Step [4840/5000], Content Loss: 90.9621, Style Loss: 257566096.0000
Step [4850/5000], Content Loss: 90.9756, Style Loss: 257201568.0000
Step [4860/5000], Content Loss: 90.9899, Style Loss: 256855216.0000
Step [4870/5000], Content Loss: 91.0036, Style Loss: 256477552.0000
Step [4880/5000], Content Loss: 91.0191, Style Loss: 256127024.0000
Step [4890/5000], Content Loss: 91.0332, Style Loss: 255794784.0000
Step [4900/5000], Content Loss: 91.0459, Style Loss: 255433952.0000
Step [4910/5000], Content Loss: 91.0617, Style Loss: 255089968.0000
Step [4920/5000], Content Loss: 91.0733, Style Loss: 254741968.0000
Step [4930/5000], Content Loss: 91.0890, Style Loss: 254383136.0000
Step [4940/5000], Content Loss: 91.1015, Style Loss: 254048224.0000
Step [4950/5000], Content Loss: 91.1160, Style Loss: 253700608.0000
Step [4960/5000], Content Loss: 91.1338, Style Loss: 253371728.0000
Step [4970/5000], Content Loss: 91.1449, Style Loss: 253062304.0000
Step [4980/5000], Content Loss: 91.1602, Style Loss: 252754048.0000
Step [4990/5000], Content Loss: 91.1777, Style Loss: 252447952.0000
Step [5000/5000], Content Loss: 91.1934, Style Loss: 252113584.0000
其实可以看出，style loss 的数值很高，学到的 style 也很多：上面 style loss 大的那张图片显然学到了更多的 style。需要注意的是，两组日志中 style loss 的定义或归一化方式可能不同，数值尺度相差好几个数量级，不宜直接比较大小。




# -*- coding: utf-8 -*-
from __future__ import division
from torch.backends import cudnn
from torch.autograd import Variable
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np

# Run on the GPU whenever one is available; `dtype` is the tensor type that
# load_image casts every loaded image to.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

# Load an image file and convert it into a 4D tensor.
# unsqueeze(0) adds the batch dimension needed by the conv layers.
def load_image(image_path, transform=None, max_size=None, shape=None):
    """Load an image from disk, optionally resize it, and transform it.

    Args:
        image_path: path of the image file to open.
        transform: callable applied to the PIL image; its result gets a
            leading batch dimension via ``unsqueeze(0)``. Required, because
            the final cast to ``dtype`` needs a tensor.
        max_size: if given, rescale so the longer side becomes ``max_size``
            (aspect ratio preserved, sizes truncated to int).
        shape: if given, resize to exactly this size. PIL ordering applies,
            i.e. ``(width, height)``, not tensor ``(height, width)``.

    Returns:
        ``transform(image).unsqueeze(0)`` cast to the module-level ``dtype``.

    Raises:
        ValueError: if ``transform`` is None.
    """
    # convert('RGB') drops an alpha channel / expands grayscale so the
    # network always receives 3 channels.
    image = Image.open(image_path).convert('RGB')
    if max_size is not None:
        scale = max_size / max(image.size)
        size = np.array(image.size) * scale
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        image = image.resize(tuple(int(v) for v in size), Image.LANCZOS)
    if shape is not None:
        image = image.resize(tuple(int(v) for v in shape), Image.LANCZOS)
    if transform is None:
        raise ValueError('load_image needs a transform that returns a tensor')
    image = transform(image).unsqueeze(0)
    return image.type(dtype)

# Pretrained VGGNet
class VGGNet(nn.Module):
    """VGG-19 feature extractor that returns five intermediate activations.

    ``select`` holds the child-module names (strings '0', '5', ...) of
    ``vgg19().features`` whose outputs are collected: conv1_1 ~ conv5_1.
    """

    def __init__(self):
        """Select conv1_1 ~ conv5_1 activation maps."""
        super(VGGNet, self).__init__()
        self.select = ['0', '5', '10', '19', '28']
        self.vgg = models.vgg19(pretrained=True).features

    def forward(self, x):
        """Extract 5 conv activation maps from an input image.

        Args:
            x: 4D tensor of shape (1, 3, height, width).

        Returns:
            features: a list containing 5 conv activation maps, in layer order.
        """
        features = []
        for name, layer in self.vgg.named_children():
            x = layer(x)  # run every layer in order, keep the selected outputs
            if name in self.select:
                # NOTE(review): if the following layer is an in-place ReLU (as
                # in torchvision's VGG), this stored tensor is overwritten with
                # its post-ReLU values — confirm that is intended.
                features.append(x)
        return features

def main(config):
    """Run style transfer with an eigenvalue-based style loss.

    This variant of the Gram-matrix version does not compare Gram matrices
    element-wise. Instead it compares their singular values; for the
    symmetric PSD Gram matrix these equal its eigenvalues. As before, the
    optimized variable is the image ``target`` (initialized from the content
    image), not the network parameters.

    Args:
        config: argparse namespace providing ``content``, ``style``,
            ``max_size``, ``total_step``, ``log_step``, ``sample_step``,
            ``style_weight`` and ``lr``.

    Side effects:
        Prints the losses every ``log_step`` steps, writes
        ``output-<step>.png`` every ``sample_step`` steps, and prints the
        VGG weights at steps 0 and 21.
    """
    # Image preprocessing with the ImageNet statistics used by the pretrained
    # VGG (see the torchvision pretrained-models documentation).
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    # ToTensor converts the PIL image to a float tensor before Normalize.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])
    # Exact inverse of the normalization above, used when saving images.
    denorm = transforms.Normalize([-m / s for m, s in zip(mean, std)],
                                  [1 / s for s in std])

    # Load content and style images.
    # PIL's resize takes (width, height) = (size(3), size(2)) of an NCHW
    # tensor, so this makes style.size() == content.size().
    content = load_image(config.content, transform, max_size=config.max_size)
    style = load_image(config.style, transform,
                       shape=[content.size(3), content.size(2)])

    # Optimize the pixels of `target` (not the network parameters), starting
    # from a copy of the content image. Its features should stay close to
    # the content's, and its Gram spectra close to the style's.
    target = content.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])

    vgg = VGGNet()
    # VGG is only a fixed feature extractor: its weights need no gradients.
    for param in vgg.parameters():
        param.requires_grad_(False)
    if use_cuda:
        vgg.cuda()

    # Content features and the style images' Gram spectra do not depend on
    # `target`, so compute them once instead of on every step.
    with torch.no_grad():
        content_features = vgg(content)
        style_spectra = []
        for f3 in vgg(style):
            _, c, h, w = f3.size()
            f3 = f3.view(c, h * w)
            style_spectra.append(torch.svd(torch.mm(f3, f3.t()))[1])

    for step in range(config.total_step):
        # The same target image goes through the network on every step; the
        # gradient of the loss flows back to it.
        target_features = vgg(target)

        style_loss = 0
        content_loss = 0
        for f1, f2, lam2 in zip(target_features, content_features, style_spectra):
            # Content loss: mean squared error between the raw feature maps.
            content_loss += torch.mean((f1 - f2) ** 2)

            # Gram matrix of the target: flatten each of the c channels into
            # a vector of length h*w, then take all channel-pair inner
            # products (a c x c matrix).
            _, c, h, w = f1.size()
            f1 = f1.view(c, h * w)
            gram = torch.mm(f1, f1.t())
            lam1 = torch.svd(gram)[1]

            # Style loss: RMS distance between the singular values of the two
            # Gram matrices, with a fixed 0.001 scale factor.
            # (Gram-distance alternative: torch.mean((g1 - g3) ** 2) / (c * h * w),
            # where dividing by c*h*w keeps the two terms on the same scale.)
            style_loss += 0.001 * torch.mean((lam1 - lam2) ** 2) ** 0.5

        # Compute total loss, backprop into the target image and update it.
        loss = content_loss + config.style_weight * style_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % config.log_step == 0:
            print('Step [%d/%d], Content Loss: %.4f, Style Loss: %.4f'
                  % (step + 1, config.total_step,
                     content_loss.item(), style_loss.item()))

        if (step + 1) % config.sample_step == 0:
            # Save the generated image: undo normalization, clip to [0, 1].
            img = target.detach().cpu().squeeze()
            img = denorm(img).clamp_(0, 1)
            torchvision.utils.save_image(img, 'output-%d.png' % (step + 1))

        # Debug: dump the VGG weights at step 0 and step 21 to check that the
        # optimizer does not change them (only `target` is optimized).
        if step in (0, 21):
            for name, weight in vgg.state_dict().items():
                print(name, weight)
if __name__ == "__main__":
    # Command-line options for the style-transfer run.
    parser = argparse.ArgumentParser()
    parser.add_argument('--content', type=str, default='./png/content.png')
    parser.add_argument('--style', type=str, default='./png/style.png')
    parser.add_argument('--max_size', type=int, default=400)
    parser.add_argument('--total_step', type=int, default=5000)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=1000)
    parser.add_argument('--style_weight', type=float, default=100)
    parser.add_argument('--lr', type=float, default=0.003)
    config = parser.parse_args()
    # Actually start the run with the parsed options.
    main(config)
同样地，贵系关于 manifold learning 也有内容相关的中文讲义。



