================================= Copyright Notice =================================

Copyright notice: this is an original post by the author. Do not repost without permission.

Please contact me through the blog platform's messaging system.

Do not use for academic citation without the author's authorization.

Do not use for commercial publishing, commercial printing, commercial citation, or other commercial purposes without the author's authorization.

 

This article is revised and improved from time to time; to be sure you are reading the correct version, please read it at the original link.                                                               <-------- some day I'm going to build my own template and get rid of this potato

Link to this post: https://www.cnblogs.com/wlsandwho/p/18682089

Wall of shame: http://www.cnblogs.com/wlsandwho/p/4206472.html

=======================================================================

Working through a book, building a cats-vs-dogs classifier.

===========================

Using TensorFlow 2.15.0.

===========================

With the pretrained model alone, the code works fine.

With the pretrained model plus data augmentation, saving the model from the ModelCheckpoint callback raises an error.

With the pretrained model plus fine-tuning, saving the model from the ModelCheckpoint callback raises an error.

===========================

The key line of the error message is:

TypeError: Cannot serialize object Ellipsis of type <class 'ellipsis'>. To be serializable, a class must implement the `get_config()` method. 

I've seen plenty of people asking online why this happens and how to fix it, but nobody seems to have posted a good solution.

So here is a write-up of my own solutions.

===========================

The key to solving a problem is finding the actual problem.

Here, comparing the working version against the failing ones should suggest the cause: when we glue the two pieces together, the thing that cannot be serialized is the call to tf.keras.applications.vgg16.preprocess_input sitting directly inside the model; it is a plain function, not a Keras layer.
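For reference, a minimal sketch of the failing pattern as I understand it (the file name is made up, and I have not run this snippet in isolation, but it is exactly the pattern every fix below removes):

import tensorflow as tf

inputs=tf.keras.Input(shape=(180,180,3))
x=tf.keras.applications.vgg16.preprocess_input(inputs)  # plain function, not a Layer
x=tf.keras.layers.Flatten()(x)
outputs=tf.keras.layers.Dense(1,activation="sigmoid")(x)
model=tf.keras.Model(inputs,outputs)
model.save("repro.keras")  # expected: TypeError: Cannot serialize object Ellipsis ...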

===========================

According to the official guide, https://keras.io/guides/serialization_and_saving/#custom-objects, there are several ways to solve this. I'll jot them down here.

===========================

Method 1: move preprocess_input out of the model

If it can't be saved, then don't save it: take the preprocessing out of the model entirely and apply it to the tf.data datasets instead, as the excerpt below shows.
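The key lines, excerpted from the full script that follows (train_dataset, validation_dataset and test_dataset are built there with image_dataset_from_directory):

train_dataset=train_dataset.map(lambda dat,lab:(tf.keras.applications.vgg16.preprocess_input(dat),lab))
validation_dataset=validation_dataset.map(lambda dat,lab:(tf.keras.applications.vgg16.preprocess_input(dat),lab))
test_dataset=test_dataset.map(lambda dat,lab:(tf.keras.applications.vgg16.preprocess_input(dat),lab))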

Using the "pretrained model + data augmentation" case as the example, the full code is:

import os
import shutil
import pathlib

import cv2
import tensorflow as tf
import matplotlib.pyplot as plt

# Paths (the Kaggle variants are kept for reference)
# ori_dir=pathlib.Path("/kaggle/input/dogsandcats/PetImages")
# new_dir=pathlib.Path("/kaggle/working/PetImages_sub")
# savefile="/kaggle/working/convnet_from_scratch.keras"
# logdir="/kaggle/working/tensorboard_mycatvsdog"

ori_dir=pathlib.Path(r"E:\my_AI_stuff\some_data\cats-vs-dogs\PetImages")
new_dir=pathlib.Path(r"E:\my_AI_stuff\mytest\PetImages_sub")
savefile="vgg16_data_improve_convnet_from_scratch.keras"
logdir="vgg16_data_improve_tensorboard_mycatvsdog"

# Data preparation
def is_legal_image(image_path):
    # The file must exist
    if not os.path.exists(image_path):
        print(f"image file does not exist: {image_path}")
        return False

    # The file must not be empty
    file_size = os.path.getsize(image_path)
    if file_size == 0:
        print(f"image file is empty: {image_path}")
        return False

    # The extension must be a known image format
    valid_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
    file_extension = os.path.splitext(image_path)[1].lower()
    if file_extension not in valid_extensions:
        print(f"invalid image format: {file_extension} of {image_path}")
        return False

    # The file must actually decode as an image
    try:
        img=cv2.imread(str(image_path))
        if img is None:
            print(f"cannot open image file: {image_path}")
            return False
    except Exception as e:
        print(f"cannot open image file: {e} of {image_path}")
        return False

    # The array must be 3-dimensional (height, width, channels); note that
    # len(img.shape) is the number of dimensions, not the number of channels
    if len(img.shape) != 3:
        print(f"unexpected number of dimensions: {len(img.shape)} of {image_path}")
        return False

    # Height and width must be positive
    height, width, _ = img.shape
    if height <= 0 or width <= 0:
        print(f"bad image size: {image_path}, height: {height}, width: {width}")
        return False

    # Pixel values must lie in [0, 255]
    min_pixel_value = img.min()
    max_pixel_value = img.max()
    if min_pixel_value < 0 or max_pixel_value > 255:
        print(f"bad pixel values: {image_path}, min: {min_pixel_value}, max: {max_pixel_value}")
        return False

    # The image looks fine
    img=None
    return True

def make_sub_dat(oridir,newdir,subdirname,bidx,eidx):
    for category in ("Cat","Dog"):
        dstdir=newdir/subdirname/category
        os.makedirs(dstdir)
        filenames=[f"{i}.jpg" for i in range(bidx,eidx)]
        for f in filenames:
            pathfile=oridir/category/f
            if is_legal_image(pathfile):
                shutil.copy(src=pathfile,dst=dstdir/f)

if not pathlib.Path(new_dir).exists():
    print("make a sub data set for train, validation and test")
    make_sub_dat(ori_dir,new_dir,"train",0,1000)
    make_sub_dat(ori_dir,new_dir,"validation",1000,1500)
    make_sub_dat(ori_dir,new_dir,"test",1500,2500)
    print("done")
else:
    print("the sub data set exists, skip")

train_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"train",image_size=(180,180),batch_size=32)
validation_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"validation",image_size=(180,180),batch_size=32)
test_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"test",image_size=(180,180),batch_size=32)

# Apply VGG16 preprocessing in the input pipeline instead of inside the model
train_dataset=train_dataset.map(lambda dat,lab:(tf.keras.applications.vgg16.preprocess_input(dat),lab))
validation_dataset=validation_dataset.map(lambda dat,lab:(tf.keras.applications.vgg16.preprocess_input(dat),lab))
test_dataset=test_dataset.map(lambda dat,lab:(tf.keras.applications.vgg16.preprocess_input(dat),lab))

if pathlib.Path(savefile).exists():
    model=tf.keras.models.load_model(savefile)
    testloss, testacc = model.evaluate(test_dataset)
    print(f"testloss:{testloss},testacc:{testacc}")
    quit(2)

# Pretrained convolutional base
convbase=tf.keras.applications.vgg16.VGG16(include_top=False,
                                           weights="imagenet",
                                           input_shape=(180,180,3))
convbase.trainable=False  # freeze the convolutional base
convbase.summary()

# Model
data_augmentation=tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.2),
])

inputs=tf.keras.Input(shape=(180,180,3))
x=data_augmentation(inputs)
x=convbase(x)
x=tf.keras.layers.Flatten()(x)
x=tf.keras.layers.Dense(256)(x)
x=tf.keras.layers.Dropout(0.5)(x)
outputs=tf.keras.layers.Dense(1,activation="sigmoid")(x)
model=tf.keras.Model(inputs,outputs)

model.summary()

model.compile(optimizer="rmsprop",
              loss=tf.keras.losses.binary_crossentropy,
              metrics=[tf.keras.metrics.binary_accuracy])

mycallbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=savefile,
                                                save_best_only=True,
                                                monitor="val_loss"),
             tf.keras.callbacks.TensorBoard(log_dir=logdir)]

history=model.fit(train_dataset,epochs=50,  # batch_size is already set by image_dataset_from_directory
                  validation_data=validation_dataset,
                  callbacks=mycallbacks)

accuracy = history.history["binary_accuracy"]
val_accuracy = history.history["val_binary_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(accuracy) + 1)
plt.plot(epochs, accuracy, "bo", label="Training accuracy")
plt.plot(epochs, val_accuracy, "b", label="Validation accuracy")
plt.title("Training and validation accuracy")
plt.legend()
plt.figure()
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()
plt.show()

This code quietly relies on the fact that the geometric transforms in the augmentation are independent of the color channels, so moving the channel-wise preprocessing in front of them changes nothing.

So it is a trick for this particular setup, not a general approach.
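A quick standalone check of that claim (a sketch, independent of the training script): VGG16's preprocessing only reorders and shifts the color channels, while a flip only moves pixels around, so the two commute exactly; rotations and zooms commute too, up to interpolation and border effects.

import numpy as np
import tensorflow as tf

x=tf.random.uniform((1,180,180,3),maxval=255.0)
pre=tf.keras.applications.vgg16.preprocess_input
flip_then_pre=pre(tf.image.flip_left_right(x))
pre_then_flip=tf.image.flip_left_right(pre(x))
print(np.allclose(flip_then_pre.numpy(),pre_then_flip.numpy()))  # True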

===========================

Method 2: use a Lambda layer

Wrap preprocess_input in a tf.keras.layers.Lambda layer; serialization knows how to handle a Lambda directly.

Since a Lambda wraps executable code, you are responsible for making sure that code is safe.

That is why Keras requires safe_mode to be turned off when loading such a model.

Personally I find this the quickest fix to apply; it can be cleaned up into something more formal later.
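The two essential pieces, excerpted from the full script that follows:

# In the model, wrap the function in a Lambda layer:
x=tf.keras.layers.Lambda(tf.keras.applications.vgg16.preprocess_input)(x)

# When loading, opt in to running the code the Lambda carries:
model = tf.keras.models.load_model(savefile, safe_mode=False)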

Using the "pretrained model + fine-tuning" case as the example, the full code is:

import os
import shutil
import pathlib

import cv2
import tensorflow as tf
import matplotlib.pyplot as plt

# Paths (the Kaggle variants are kept for reference)
# ori_dir=pathlib.Path("/kaggle/input/dogsandcats/PetImages")
# new_dir=pathlib.Path("/kaggle/working/PetImages_sub")
# savefile="/kaggle/working/vgg16_fine_tune_convnet_from_scratch.keras"
# logdir="/kaggle/working/vgg16_fine_tune_tensorboard_mycatvsdog"

ori_dir=pathlib.Path(r"E:\my_AI_stuff\some_data\cats-vs-dogs\PetImages")
new_dir=pathlib.Path(r"E:\my_AI_stuff\mytest\PetImages_sub")
savefile="vgg16_fine_tune2_convnet_from_scratch.keras"
logdir="vgg16_fine_tune2_tensorboard_mycatvsdog"

# Data preparation
def is_legal_image(image_path):
    # The file must exist
    if not os.path.exists(image_path):
        print(f"image file does not exist: {image_path}")
        return False

    # The file must not be empty
    file_size = os.path.getsize(image_path)
    if file_size == 0:
        print(f"image file is empty: {image_path}")
        return False

    # The extension must be a known image format
    valid_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
    file_extension = os.path.splitext(image_path)[1].lower()
    if file_extension not in valid_extensions:
        print(f"invalid image format: {file_extension} of {image_path}")
        return False

    # The file must actually decode as an image
    try:
        img=cv2.imread(str(image_path))
        if img is None:
            print(f"cannot open image file: {image_path}")
            return False
    except Exception as e:
        print(f"cannot open image file: {e} of {image_path}")
        return False

    # The array must be 3-dimensional (height, width, channels); note that
    # len(img.shape) is the number of dimensions, not the number of channels
    if len(img.shape) != 3:
        print(f"unexpected number of dimensions: {len(img.shape)} of {image_path}")
        return False

    # Height and width must be positive
    height, width, _ = img.shape
    if height <= 0 or width <= 0:
        print(f"bad image size: {image_path}, height: {height}, width: {width}")
        return False

    # Pixel values must lie in [0, 255]
    min_pixel_value = img.min()
    max_pixel_value = img.max()
    if min_pixel_value < 0 or max_pixel_value > 255:
        print(f"bad pixel values: {image_path}, min: {min_pixel_value}, max: {max_pixel_value}")
        return False

    # The image looks fine
    img=None
    return True

def make_sub_dat(oridir,newdir,subdirname,bidx,eidx):
    for category in ("Cat","Dog"):
        dstdir=newdir/subdirname/category
        os.makedirs(dstdir)
        filenames=[f"{i}.jpg" for i in range(bidx,eidx)]
        for f in filenames:
            pathfile=oridir/category/f
            if is_legal_image(pathfile):
                shutil.copy(src=pathfile,dst=dstdir/f)

if not pathlib.Path(new_dir).exists():
    print("make a sub data set for train, validation and test")
    make_sub_dat(ori_dir,new_dir,"train",0,1000)
    make_sub_dat(ori_dir,new_dir,"validation",1000,1500)
    make_sub_dat(ori_dir,new_dir,"test",1500,2500)
    print("done")
else:
    print("the sub data set exists, skip")

train_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"train",image_size=(180,180),batch_size=32)
validation_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"validation",image_size=(180,180),batch_size=32)
test_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"test",image_size=(180,180),batch_size=32)

if pathlib.Path(savefile).exists():
    # Lambda layers execute arbitrary code, so loading requires safe_mode=False
    model = tf.keras.models.load_model(savefile, safe_mode=False)
    testloss, testacc = model.evaluate(test_dataset)
    print(f"testloss:{testloss},testacc:{testacc}")
    quit(2)

# Pretrained convolutional base: unfreeze only its last few layers for fine-tuning
convbase=tf.keras.applications.vgg16.VGG16(include_top=False,
                                           weights="imagenet",
                                           input_shape=(180,180,3))
convbase.trainable=True
for lay in convbase.layers[:-4]:
    lay.trainable=False
convbase.summary()

# Model
data_augmentation=tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.2),
])

inputs=tf.keras.Input(shape=(180,180,3))
x=data_augmentation(inputs)
# Wrapping preprocess_input in a Lambda layer makes it serializable
x=tf.keras.layers.Lambda(tf.keras.applications.vgg16.preprocess_input)(x)
x=convbase(x)
x=tf.keras.layers.Flatten()(x)
x=tf.keras.layers.Dense(256)(x)
x=tf.keras.layers.Dropout(0.5)(x)
outputs=tf.keras.layers.Dense(1,activation="sigmoid")(x)
model=tf.keras.Model(inputs,outputs)
model.summary()

model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-5),
              loss=tf.keras.losses.binary_crossentropy,
              metrics=[tf.keras.metrics.binary_accuracy])
# model.save(savefile)  # without the Lambda wrapper, this raised: TypeError: Cannot serialize object Ellipsis of type <class 'ellipsis'>. To be serializable, a class must implement the `get_config()` method.

mycallbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=savefile,
                                                save_best_only=True,
                                                monitor="val_loss"),
             tf.keras.callbacks.TensorBoard(log_dir=logdir)]

history=model.fit(train_dataset,epochs=1,
                  validation_data=validation_dataset,
                  callbacks=mycallbacks)

accuracy = history.history["binary_accuracy"]
val_accuracy = history.history["val_binary_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(accuracy) + 1)
plt.plot(epochs, accuracy, "bo", label="Training accuracy")
plt.plot(epochs, val_accuracy, "b", label="Validation accuracy")
plt.title("Training and validation accuracy")
plt.legend()
plt.figure()
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()
plt.show()

 

===========================

Method 3: implement what serialization requires

If you do want the model saved, then do what the framework asks: wrap preprocess_input in a custom Layer that implements get_config(), and register it as a serializable class, as sketched below.
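The core of it, excerpted from the full script that follows: a Layer subclass that wraps preprocess_input and is registered so load_model can reconstruct it without custom_objects:

@keras.saving.register_keras_serializable(package="MyLayers")
class MyPreprocessInput(tf.keras.layers.Layer):
    def call(self,data_inputs):
        return tf.keras.applications.vgg16.preprocess_input(data_inputs)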

Using the "pretrained model + fine-tuning" case as the example, the full code is:

import os
import shutil
import pathlib

import cv2
import keras  # for keras.saving.register_keras_serializable
import tensorflow as tf
import matplotlib.pyplot as plt

# Paths (the Kaggle variants are kept for reference)
# ori_dir=pathlib.Path("/kaggle/input/dogsandcats/PetImages")
# new_dir=pathlib.Path("/kaggle/working/PetImages_sub")
# savefile="/kaggle/working/vgg16_fine_tune_convnet_from_scratch.keras"
# logdir="/kaggle/working/vgg16_fine_tune_tensorboard_mycatvsdog"

ori_dir=pathlib.Path(r"E:\my_AI_stuff\some_data\cats-vs-dogs\PetImages")
new_dir=pathlib.Path(r"E:\my_AI_stuff\mytest\PetImages_sub")
savefile="vgg16_fine_tune4_convnet_from_scratch.keras"
logdir="vgg16_fine_tune4_tensorboard_mycatvsdog"

# Data preparation
def is_legal_image(image_path):
    # The file must exist
    if not os.path.exists(image_path):
        print(f"image file does not exist: {image_path}")
        return False

    # The file must not be empty
    file_size = os.path.getsize(image_path)
    if file_size == 0:
        print(f"image file is empty: {image_path}")
        return False

    # The extension must be a known image format
    valid_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
    file_extension = os.path.splitext(image_path)[1].lower()
    if file_extension not in valid_extensions:
        print(f"invalid image format: {file_extension} of {image_path}")
        return False

    # The file must actually decode as an image
    try:
        img=cv2.imread(str(image_path))
        if img is None:
            print(f"cannot open image file: {image_path}")
            return False
    except Exception as e:
        print(f"cannot open image file: {e} of {image_path}")
        return False

    # The array must be 3-dimensional (height, width, channels); note that
    # len(img.shape) is the number of dimensions, not the number of channels
    if len(img.shape) != 3:
        print(f"unexpected number of dimensions: {len(img.shape)} of {image_path}")
        return False

    # Height and width must be positive
    height, width, _ = img.shape
    if height <= 0 or width <= 0:
        print(f"bad image size: {image_path}, height: {height}, width: {width}")
        return False

    # Pixel values must lie in [0, 255]
    min_pixel_value = img.min()
    max_pixel_value = img.max()
    if min_pixel_value < 0 or max_pixel_value > 255:
        print(f"bad pixel values: {image_path}, min: {min_pixel_value}, max: {max_pixel_value}")
        return False

    # The image looks fine
    img=None
    return True

def make_sub_dat(oridir,newdir,subdirname,bidx,eidx):
    for category in ("Cat","Dog"):
        dstdir=newdir/subdirname/category
        os.makedirs(dstdir)
        filenames=[f"{i}.jpg" for i in range(bidx,eidx)]
        for f in filenames:
            pathfile=oridir/category/f
            if is_legal_image(pathfile):
                shutil.copy(src=pathfile,dst=dstdir/f)

if not pathlib.Path(new_dir).exists():
    print("make a sub data set for train, validation and test")
    make_sub_dat(ori_dir,new_dir,"train",0,1000)
    make_sub_dat(ori_dir,new_dir,"validation",1000,1500)
    make_sub_dat(ori_dir,new_dir,"test",1500,2500)
    print("done")
else:
    print("the sub data set exists, skip")

train_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"train",image_size=(180,180),batch_size=32)
validation_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"validation",image_size=(180,180),batch_size=32)
test_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"test",image_size=(180,180),batch_size=32)

# A custom layer that wraps preprocess_input and implements the serialization
# interface; registering it lets load_model find it without custom_objects
@keras.saving.register_keras_serializable(package="MyLayers")
class MyPreprocessInput(tf.keras.layers.Layer):
    def __init__(self,**kwargs):
        super().__init__(**kwargs)

    def call(self,data_inputs):
        return tf.keras.applications.vgg16.preprocess_input(data_inputs)

    def get_config(self):
        config=super().get_config()
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


if pathlib.Path(savefile).exists():
    model=tf.keras.models.load_model(savefile)
    testloss, testacc = model.evaluate(test_dataset)
    print(f"testloss:{testloss},testacc:{testacc}")
    quit(2)

# Pretrained convolutional base: unfreeze only its last few layers for fine-tuning
convbase=tf.keras.applications.vgg16.VGG16(include_top=False,
                                           weights="imagenet",
                                           input_shape=(180,180,3))
convbase.trainable=True
for lay in convbase.layers[:-4]:
    lay.trainable=False
convbase.summary()

# Model
data_augmentation=tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.2),
])

inputs=tf.keras.Input(shape=(180,180,3))
x=data_augmentation(inputs)
x=MyPreprocessInput()(x)
x=convbase(x)
x=tf.keras.layers.Flatten()(x)
x=tf.keras.layers.Dense(256)(x)
x=tf.keras.layers.Dropout(0.5)(x)
outputs=tf.keras.layers.Dense(1,activation="sigmoid")(x)
model=tf.keras.Model(inputs,outputs)
model.summary()

model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-5),
              loss=tf.keras.losses.binary_crossentropy,
              metrics=[tf.keras.metrics.binary_accuracy])
# model.save(savefile)  # without the custom layer, this raised: TypeError: Cannot serialize object Ellipsis of type <class 'ellipsis'>. To be serializable, a class must implement the `get_config()` method.

mycallbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=savefile,
                                                save_best_only=True,
                                                monitor="val_loss"),
             tf.keras.callbacks.TensorBoard(log_dir=logdir)]

history=model.fit(train_dataset,epochs=1,  # 50 for a real run
                  validation_data=validation_dataset,
                  callbacks=mycallbacks)

accuracy = history.history["binary_accuracy"]
val_accuracy = history.history["val_binary_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(accuracy) + 1)
plt.plot(epochs, accuracy, "bo", label="Training accuracy")
plt.plot(epochs, val_accuracy, "b", label="Validation accuracy")
plt.title("Training and validation accuracy")
plt.legend()
plt.figure()
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()
plt.show()

===================

That covers the conventional approaches; I don't know any back-alley tricks beyond these.

===================

You could also work all of this out yourself from the official docs: https://keras.io/guides/serialization_and_saving/#custom-objects

===================

All right, the obstacle is cleared, and I've walked through each of the methods.