GAN生成图片

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
 
# Hyper-parameters shared by the model definitions and the training loop.
# Run on the first CUDA device when available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

latent_size = 100          # length of the noise vector fed to the generator
hidden_size = 64           # width of the hidden Linear layers in both networks
batch_size = 64            # images per DataLoader batch
image_size = 3 * 96 * 96   # flattened image length: channels * height * width
 
# 定义生成器和判别器的类
class Generator(nn.Module):
    """Fully-connected GAN generator mapping noise vectors to images.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear -> Tanh, so every
    output value lies in [-1, 1]. The flat output is reshaped to ``img_shape``.

    Args:
        latent_dim: Length of each input noise vector. ``None`` (default) uses
            the module constant ``latent_size``.
        hidden_dim: Width of the two hidden layers. ``None`` (default) uses the
            module constant ``hidden_size``.
        img_shape: ``(channels, height, width)`` of the produced images. The
            last Linear layer's width is derived from it, so the layer size and
            the reshape in ``forward`` can never disagree.
    """

    def __init__(self, latent_dim=None, hidden_dim=None, img_shape=(3, 96, 96)):
        super(Generator, self).__init__()
        # Defaults are resolved here (not in the signature) so that
        # Generator() keeps following the module-level hyper-parameters.
        if latent_dim is None:
            latent_dim = latent_size
        if hidden_dim is None:
            hidden_dim = hidden_size
        self.img_shape = tuple(img_shape)
        out_features = torch.Size(self.img_shape).numel()
        # Attribute name and layer order are unchanged, so state_dict keys
        # ("model.0.weight", ...) remain compatible with saved checkpoints.
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_features),
            nn.Tanh()
        )

    def forward(self, x):
        """Map a batch of noise vectors ``(N, latent_dim)`` to ``(N, *img_shape)``."""
        x = self.model(x)
        return x.view(x.size(0), *self.img_shape)
 
class Discriminator(nn.Module):
    """Fully-connected GAN discriminator.

    Each input sample is flattened and passed through
    Linear -> LeakyReLU(0.2) -> Linear -> LeakyReLU(0.2) -> Linear -> Sigmoid,
    giving one score in (0, 1) per sample with shape ``(N, 1)``.

    Args:
        in_features: Number of values in one flattened input sample. ``None``
            (default) uses the module constant ``image_size``.
        hidden_dim: Width of the two hidden layers. ``None`` (default) uses the
            module constant ``hidden_size``.
    """

    def __init__(self, in_features=None, hidden_dim=None):
        super(Discriminator, self).__init__()
        # Resolved at construction time so Discriminator() keeps following
        # the module-level hyper-parameters.
        if in_features is None:
            in_features = image_size
        if hidden_dim is None:
            hidden_dim = hidden_size
        # Same attribute name and layer order as before: state_dict keys unchanged.
        self.model = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Score a batch ``(N, ...)``; returns a ``(N, 1)`` tensor of probabilities."""
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. permuted tensors.
        x = x.reshape(x.size(0), -1)
        return self.model(x)
 
# Instantiate the generator and discriminator on the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid output; Adam for both nets.
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Load the dataset: resize to 96x96, convert to a tensor, then Normalize with
# mean=std=0.5 per channel, i.e. (x - 0.5) / 0.5, matching the generator's
# Tanh output range of [-1, 1].
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# ImageFolder reads class sub-folders under ./d; the class labels are unused.
dataset = datasets.ImageFolder(root="./d", transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Output folder for sample grids and checkpoints, created once up front.
os.makedirs("./images", exist_ok=True)

# One fixed noise batch for the per-epoch sample grid (5x5 = 25 images), so
# successive grids show how the same latent points evolve during training.
num_samples = 25
fixed_z = torch.randn(num_samples, latent_size, device=device)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    for idx, (images, _) in enumerate(dataloader):
        # The last batch can be smaller than batch_size. Use a separate name
        # so the global batch_size hyper-parameter is not overwritten.
        cur_batch = images.size(0)
        images = images.to(device)

        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # ---- Discriminator step: push real -> 1 and fake -> 0 ----
        optimizer_D.zero_grad()

        outputs = discriminator(images)
        real_loss = criterion(outputs, real_labels)
        real_score = outputs.mean().item()

        z = torch.randn(cur_batch, latent_size, device=device)
        fake_images = generator(z)
        # detach(): this step must not backpropagate into the generator.
        outputs = discriminator(fake_images.detach())
        fake_loss = criterion(outputs, fake_labels)
        fake_score = outputs.mean().item()

        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step: make D classify fresh fakes as real ----
        optimizer_G.zero_grad()

        z = torch.randn(cur_batch, latent_size, device=device)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)

        g_loss.backward()
        optimizer_G.step()

        # Progress report every 100 batches.
        if idx % 100 == 0:
            print("Epoch [{}/{}], Batch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, Real Score: {:.2f}, Fake Score: {:.2f}".format(
                epoch+1, num_epochs, idx+1, len(dataloader), d_loss.item(), g_loss.item(), real_score, fake_score))

    # Save a 5x5 grid of generated samples after every epoch; no autograd
    # graph is needed for this.
    with torch.no_grad():
        samples = generator(fixed_z)
    save_image(samples, './images/fake_images-{}.png'.format(epoch+1), nrow=5, normalize=True)

# netG/netD files hold the trained network weights (loadable with
# load_state_dict). Optimizer state is stored separately so that training
# can be resumed.
torch.save(generator.state_dict(), './images/netG_%03d.pth' % num_epochs)
torch.save(discriminator.state_dict(), './images/netD_%03d.pth' % num_epochs)
torch.save(optimizer_G.state_dict(), './images/optG_%03d.pth' % num_epochs)
torch.save(optimizer_D.state_dict(), './images/optD_%03d.pth' % num_epochs)

  批量移动文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import os
import shutil
 
# Folder to process: images in each of its immediate sub-folders are moved up
# into it. Raw string so the Windows backslashes are not read as escapes.
folder_path = r"D:\BaiduNetdiskDownload\CASIA-FaceV5 (000-099)"

# File-name suffixes (compared case-insensitively) of the files to move.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".bmp")


def move_images_to_parent(root, extensions=IMAGE_EXTENSIONS):
    """Move image files from every immediate sub-folder of *root* into *root*.

    Only one level of sub-folders is scanned; deeper folders are not touched.
    Existing files in *root* are never overwritten. On a name clash the moved
    file is renamed to ``<subfolder>_<name>``. If that name is also taken, the
    file stays where it is and a message is printed.

    Args:
        root: Folder to flatten.
        extensions: File-name suffixes to move, matched case-insensitively.

    Returns:
        List of destination paths of the files that were moved.

    Raises:
        FileNotFoundError: If *root* does not exist.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    moved = []
    with os.scandir(root) as entries:
        subfolders = sorted(entry.name for entry in entries if entry.is_dir())
    for subfolder in subfolders:
        sub_path = os.path.join(root, subfolder)
        for name in sorted(os.listdir(sub_path)):
            src = os.path.join(sub_path, name)
            if not (os.path.isfile(src) and name.lower().endswith(suffixes)):
                continue
            dest = os.path.join(root, name)
            if os.path.exists(dest):
                # Avoid clobbering (POSIX) or crashing (Windows) on duplicates.
                dest = os.path.join(root, "{}_{}".format(subfolder, name))
                if os.path.exists(dest):
                    print("Skipping {}: {} already exists".format(src, dest))
                    continue
            shutil.move(src, dest)
            moved.append(dest)
    return moved


if __name__ == "__main__":
    move_images_to_parent(folder_path)

  

posted @ 雄子  阅读(49)  评论(0)  编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)
点击右上角即可分享
微信分享提示