大数据开发之keras代码框架应用

     总体来讲keras这个深度学习框架真的很“简易”,它体现在可参考的文档写的比较详细,不像caffe,装完以后都得靠技术博客,keras有它自己的官方文档(不过是英文的),这给初学者提供了很大的学习空间。

    在此做下代码框架应用笔记

     

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class VGGNetwork:
    def append_vgg_network(self, x_in, true_X_input):
        return x #x is output of VGG
    def load_vgg_weight(self, model):
        return model
class DiscriminatorNetwork:
    def append_gan_network(self, true_X_input):
        return x
class GenerativeNetwork:
    def create_sr_model(self, ip):
        return x
    def get_generator_output(self, input_img, srgan_model):
        return self.output_func([input_img])
class SRGANNetwork:
    def build_srgan_pretrain_model(self):
        return self.srgan_model_
    def build_discriminator_pretrain_model(self):
        return self.discriminative_model_
    def build_srgan_model(self):
        return self.srgan_model_
    def pre_train_srgan(self, image_dir, nb_images=50000, nb_epochs=1, use_small_srgan=False):
        for i in range(nb_epochs):
            for x in datagen.flow_from_directory
                if iteration % 50 == 0 and iteration != 0
                    validation//print psnr
                Train only generator + vgg network
                if iteration % 1000 == 0 and iteration != 0
                    Saving model weights
    def pre_train_discriminator(self, image_dir, nb_images=50000, nb_epochs=1, batch_size=128):
        for i in range(nb_epochs):
             for x in datagen.flow_from_directory
                 Train only discriminator
                 if iteration % 1000 == 0 and iteration != 0
                    Saving model weights
    def train_full_model(self, image_dir, nb_images=50000, nb_epochs=10):  
        for i in range(nb_epochs):
            for x in datagen.flow_from_directory
                if iteration % 50 == 0 and iteration != 0
                    validation//print psnr
                if iteration % 1000 == 0 and iteration != 0
                    Saving model weights
                Train only discriminator, disable training of srgan
                Train only generator, disable training of discriminator
if __name__ == "__main__":
    from keras.utils.visualize_util import plot
 
    # Path to MS COCO dataset
    coco_path = r"D:\Yue\Documents\Dataset\coco2014\train2014"
 
    '''
    Base Network manager for the SRGAN model
 
    Width / Height = 32 to reduce the memory requirement for the discriminator.
 
    Batch size = 1 is slower, but uses the least amount of gpu memory, and also acts as
    Instance Normalization (batch norm with 1 input image) which speeds up training slightly.
    '''
 
    srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
    srgan_network.build_srgan_model()
    #plot(srgan_network.srgan_model_, 'SRGAN.png', show_shapes=True)
 
    # Pretrain the SRGAN network
    #srgan_network.pre_train_srgan(coco_path, nb_images=80000, nb_epochs=1)
 
    # Pretrain the discriminator network
    #srgan_network.pre_train_discriminator(coco_path, nb_images=40000, nb_epochs=1, batch_size=16)
 
    # Fully train the SRGAN with VGG loss and Discriminator loss
    srgan_network.train_full_model(coco_path, nb_images=80000, nb_epochs=5)

  

posted @   圆柱模板  阅读(377)  评论(0编辑  收藏  举报
编辑推荐:
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· AI技术革命,工作效率10个最佳AI工具
点击右上角即可分享
微信分享提示