大数据开发之keras代码框架应用
总体来讲keras这个深度学习框架真的很“简易”,它体现在可参考的文档写的比较详细,不像caffe,装完以后都得靠技术博客,keras有它自己的官方文档(不过是英文的),这给初学者提供了很大的学习空间。
在此做下代码框架应用笔记
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | class VGGNetwork: def append_vgg_network( self , x_in, true_X_input): return x #x is output of VGG def load_vgg_weight( self , model): return model class DiscriminatorNetwork: def append_gan_network( self , true_X_input): return x class GenerativeNetwork: def create_sr_model( self , ip): return x def get_generator_output( self , input_img, srgan_model): return self .output_func([input_img]) class SRGANNetwork: def build_srgan_pretrain_model( self ): return self .srgan_model_ def build_discriminator_pretrain_model( self ): return self .discriminative_model_ def build_srgan_model( self ): return self .srgan_model_ def pre_train_srgan( self , image_dir, nb_images = 50000 , nb_epochs = 1 , use_small_srgan = False ): for i in range (nb_epochs): for x in datagen.flow_from_directory if iteration % 50 = = 0 and iteration ! = 0 validation / / print psnr Train only generator + vgg network if iteration % 1000 = = 0 and iteration ! = 0 Saving model weights def pre_train_discriminator( self , image_dir, nb_images = 50000 , nb_epochs = 1 , batch_size = 128 ): for i in range (nb_epochs): for x in datagen.flow_from_directory Train only discriminator if iteration % 1000 = = 0 and iteration ! = 0 Saving model weights def train_full_model( self , image_dir, nb_images = 50000 , nb_epochs = 10 ): for i in range (nb_epochs): for x in datagen.flow_from_directory if iteration % 50 = = 0 and iteration ! = 0 validation / / print psnr if iteration % 1000 = = 0 and iteration ! = 0 Saving model weights Train only discriminator, disable training of srgan Train only generator, disable training of discriminator if __name__ = = "__main__" : from keras.utils.visualize_util import plot # Path to MS COCO dataset coco_path = r "D:\Yue\Documents\Dataset\coco2014\train2014" ''' Base Network manager for the SRGAN model Width / Height = 32 to reduce the memory requirement for the discriminator. Batch size = 1 is slower, but uses the least amount of gpu memory, and also acts as Instance Normalization (batch norm with 1 input image) which speeds up training slightly. ''' srgan_network = SRGANNetwork(img_width = 32 , img_height = 32 , batch_size = 1 ) srgan_network.build_srgan_model() #plot(srgan_network.srgan_model_, 'SRGAN.png', show_shapes=True) # Pretrain the SRGAN network #srgan_network.pre_train_srgan(coco_path, nb_images=80000, nb_epochs=1) # Pretrain the discriminator network #srgan_network.pre_train_discriminator(coco_path, nb_images=40000, nb_epochs=1, batch_size=16) # Fully train the SRGAN with VGG loss and Discriminator loss srgan_network.train_full_model(coco_path, nb_images = 80000 , nb_epochs = 5 ) |
千行代码,Bug何处藏。 纵使上线又怎样,朝令改,夕断肠。
分类:
Python开发笔记
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· AI技术革命,工作效率10个最佳AI工具