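=================================Copyright Notice=================================
Copyright notice: this is an original article by the blogger. Reproduction without permission is prohibited.
Please contact me through the blog platform's messaging.
Do not cite in academic work without the author's permission.
Do not use for commercial publishing, commercial printing, commercial citation or any other commercial purpose without the author's permission.
This post is revised and improved from time to time; to make sure you read the correct version, please read it at the original link. <-------- One day I'll build my own template and get rid of this potato
Link to this post: https://www.cnblogs.com/wlsandwho/p/18682089
Wall of shame: http://www.cnblogs.com/wlsandwho/p/4206472.html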
=======================================================================
Working through the cats-vs-dogs classification example from a book.
===========================
Using tensorflow 2.15.0
===========================
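With the pretrained model alone, the code works fine.
With pretrained model + data augmentation, saving the model from the ModelCheckpoint callback raises an error.
With pretrained model + fine-tuning, saving the model from the ModelCheckpoint callback raises an error.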
===========================
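The key part of the error message is:
TypeError: Cannot serialize object Ellipsis of type <class 'ellipsis'>. To be serializable, a class must implement the `get_config()` method.
I've seen plenty of people online asking why this happens and how to fix it, but nobody seems to have posted a good solution.
So here I'm writing up my own.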
===========================
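The key to solving a problem is finding the key problem.
Comparing the working version with the failing ones makes it clear: the failures only happen when we glue the two pieces together inside the model, i.e. when preprocess_input is called on the symbolic tensor between the augmentation layers and the pretrained conv base. Something produced at that point cannot be serialized. Most likely it is the x[..., ::-1] RGB-to-BGR slice inside preprocess_input, which is where the Ellipsis in the error message would come from.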
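To make "something can't be serialized" concrete: tf.keras.applications.vgg16.preprocess_input uses caffe-style preprocessing, and as far as I can tell its RGB-to-BGR step is the slice x[..., ::-1]. When that runs on a symbolic tensor inside a functional model, the slice and its `...` get recorded in the model. The stdlib-only sketch below does not call Keras at all; it only shows what that slice key looks like to Python and that a plain JSON encoder rejects it, the same kind of failure the Keras saver reports. SliceKeyRecorder is a made-up helper for illustration.

```python
import json


class SliceKeyRecorder:
    """Made-up helper: returns whatever key Python passes to obj[...]."""

    def __getitem__(self, key):
        return key


# The RGB->BGR flip in caffe-style preprocessing is written as x[..., ::-1].
key = SliceKeyRecorder()[..., ::-1]
print(key)  # (Ellipsis, slice(None, None, -1))

# A saved .keras model stores layer configs as JSON. An op whose arguments
# contain this key cannot be written out; plain json fails the same way.
try:
    json.dumps({"slice_key": key})
except TypeError as err:
    print("TypeError:", err)
```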
===========================
The official guide https://keras.io/guides/serialization_and_saving/#custom-objects describes several ways to deal with this. Here I'll just note down a few.
===========================
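Method 1: move preprocess_input out of the model
If it can't be saved, then don't save it: run preprocess_input in the tf.data input pipeline instead of inside the model.
Taking "pretrained model + data augmentation, error when the callback saves the model" as the example, the code is: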
import os
import shutil
import pathlib

import keras.callbacks
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Paths etc.
# ori_dir=pathlib.Path("/kaggle/input/dogsandcats/PetImages")
# new_dir=pathlib.Path("/kaggle/working/PetImages_sub")
# savefile="/kaggle/working/convnet_from_scratch.keras"
# logdir="/kaggle/working/tensorboard_mycatvsdog"

ori_dir=pathlib.Path(r"E:\my_AI_stuff\some_data\cats-vs-dogs\PetImages")
new_dir=pathlib.Path(r"E:\my_AI_stuff\mytest\PetImages_sub")
savefile="vgg16_data_improve_convnet_from_scratch.keras"
logdir="vgg16_data_improve_tensorboard_mycatvsdog"

# Data preparation
import cv2
def is_legal_image(image_path):
    # Check that the image file exists
    if not os.path.exists(image_path):
        print(f"Image file does not exist: {image_path}")
        return False

    # Check the file size
    file_size = os.path.getsize(image_path)
    if file_size == 0:
        print(f"Image file is empty: {image_path}")
        return False

    # Check the file extension
    valid_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
    file_extension = os.path.splitext(image_path)[1].lower()
    if file_extension not in valid_extensions:
        print(f"Invalid image format: {file_extension} of {image_path}")
        return False

    # Check that the file can actually be decoded
    try:
        img=cv2.imread(str(image_path))  # cv2.imread wants a str path; alternative: img=Image.open(image_path)
        if img is None:
            print(f"Cannot open image file: {image_path}")
            return False
        #imgarr=np.empty(img.shape)
        #imgarr[:]=img
    except Exception as e:
        print(f"Cannot open image file: {e} of {image_path}")
        return False

    # Check the image size
    height, width, _ = img.shape
    if height <= 0 or width <= 0:
        print(f"Abnormal image size: {image_path}, height:{height} width:{width}")
        return False

    # Check the number of color channels (len(img.shape) would count dimensions, not channels)
    num_channels = img.shape[2]
    if num_channels != 3:
        print(f"Invalid number of color channels: {num_channels} of {image_path}")
        return False

    # Check the pixel value range
    min_pixel_value = img.min()
    max_pixel_value = img.max()
    if min_pixel_value < 0 or max_pixel_value > 255:
        print(f"Abnormal pixel values: {image_path}, min_pixel_value:{min_pixel_value}, max_pixel_value:{max_pixel_value}")
        return False

    # The image looks fine
    img=None
    return True

def make_sub_dat(oridir,newdir,subdirname,bidx,eidx):
    for category in ("Cat","Dog"):
        dstdir=newdir/subdirname/category  # was: f"{newdir}\\{subdirname}\\{category}"
        os.makedirs(dstdir)
        filenames=[f"{i}.jpg" for i in range(bidx,eidx)]
        for f in filenames:
            pathfile=oridir/category/f
            if is_legal_image(pathfile):
                shutil.copy(src=oridir/category/f,dst=dstdir/f)
            else:
                pass

if not pathlib.Path(new_dir).exists():
    print("make a sub data set for train, validation and test")
    make_sub_dat(ori_dir,new_dir,"train",0,1000)
    make_sub_dat(ori_dir,new_dir,"validation",1000,1500)
    make_sub_dat(ori_dir,new_dir,"test",1500,2500)
    print("done")
else:
    print("the sub data set already exists, skip")
    #quit(2)

train_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"train",image_size=(180,180),batch_size=32)
validation_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"validation",image_size=(180,180),batch_size=32)
test_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"test",image_size=(180,180),batch_size=32)

# preprocess_input now runs in the data pipeline instead of inside the model
#x=tf.keras.applications.vgg16.preprocess_input(x)
train_dataset=train_dataset.map(lambda dat,lab:(tf.keras.applications.vgg16.preprocess_input(dat),lab))
validation_dataset=validation_dataset.map(lambda dat,lab:(tf.keras.applications.vgg16.preprocess_input(dat),lab))
test_dataset=test_dataset.map(lambda dat,lab:(tf.keras.applications.vgg16.preprocess_input(dat),lab))

if pathlib.Path(savefile).exists():
    model=tf.keras.models.load_model(savefile)
    testloss, testacc = model.evaluate(test_dataset)
    print(f"testloss:{testloss},testacc:{testacc}")
    quit(2)

# Pretrained model
convbase=tf.keras.applications.vgg16.VGG16(include_top=False,
                                          weights="imagenet",
                                          input_shape=(180,180,3))
convbase.trainable=False  # freeze the conv base
convbase.summary()

# Model
data_augmentation=tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.2),
])

inputs=tf.keras.Input(shape=(180,180,3))
x=data_augmentation(inputs)
x=convbase(x)
x=tf.keras.layers.Flatten()(x)
x=tf.keras.layers.Dense(256)(x)
x=tf.keras.layers.Dropout(0.5)(x)
outputs=tf.keras.layers.Dense(1,activation="sigmoid")(x)
model=tf.keras.Model(inputs,outputs)

model.summary()

model.compile(optimizer="rmsprop",
              loss=tf.keras.losses.binary_crossentropy,
              metrics=[tf.keras.metrics.binary_accuracy])

mycallbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=savefile,
                                                save_best_only=True,
                                                monitor="val_loss"),
             tf.keras.callbacks.TensorBoard(log_dir=logdir)]

history=model.fit(train_dataset,epochs=50,  # no batch_size here: image_dataset_from_directory already batches (32)
                  validation_data=validation_dataset,
                  callbacks=mycallbacks)

#print(history.history.keys())
accuracy = history.history["binary_accuracy"]
val_accuracy = history.history["val_binary_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(accuracy) + 1)
plt.plot(epochs, accuracy, "bo", label="Training accuracy")
plt.plot(epochs, val_accuracy, "b", label="Validation accuracy")
plt.title("Training and validation accuracy")
plt.legend()
plt.figure()
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()
plt.show()
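This code quietly relies on the fact that the augmentation transforms (flip, rotation, zoom) are spatial matrix transforms that don't care about the color channels, so running preprocess_input in the input pipeline before augmentation gives the same result as running it inside the model after augmentation.
So it is not a general approach. Note also that the saved model no longer contains the preprocessing step: whoever loads it has to apply preprocess_input to the inputs themselves, as the test_dataset.map call above does.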
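A quick pure-Python check of why moving preprocessing in front of augmentation is harmless here and why it isn't in general. Assumptions: a tiny hand-written 2x2 image, the ImageNet BGR means used by caffe-style preprocessing, a horizontal flip standing in for the spatial augmentation layers, and a clip to [0, 255] standing in for a color augmentation that clips its output (RandomBrightness with a value_range behaves roughly like this, as far as I know). caffe_preprocess, hflip and clip_0_255 are hypothetical helpers, not Keras functions.

```python
# Toy 2x2 RGB image: rows of (r, g, b) pixels with values in [0, 255].
IMAGE = [
    [(255, 0, 0), (0, 255, 0)],
    [(0, 0, 255), (10, 20, 30)],
]

# ImageNet channel means in BGR order, as used by caffe-style preprocessing.
BGR_MEAN = (103.939, 116.779, 123.68)


def caffe_preprocess(img):
    """Per-pixel op: RGB -> BGR, then subtract the per-channel mean."""
    return [[tuple(c - m for c, m in zip(px[::-1], BGR_MEAN)) for px in row] for row in img]


def hflip(img):
    """Spatial op: mirror every row; pixel values are untouched."""
    return [row[::-1] for row in img]


def clip_0_255(img):
    """Value-dependent op: clamp every channel value into [0, 255]."""
    return [[tuple(min(max(c, 0), 255) for c in px) for px in row] for row in img]


# Spatial op and per-pixel op commute: the order does not matter.
print(caffe_preprocess(hflip(IMAGE)) == hflip(caffe_preprocess(IMAGE)))  # True

# A value-dependent op does not commute with mean subtraction.
print(caffe_preprocess(clip_0_255(IMAGE)) == clip_0_255(caffe_preprocess(IMAGE)))  # False
```

One caveat even for the spatial layers: RandomRotation and RandomZoom fill uncovered areas according to fill_mode. With the default "reflect" the filled pixels come from real pixels, so the order still doesn't matter; with fill_mode="constant" the fill value would land in raw pixel space in one order and in preprocessed space in the other.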
===========================
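Method 2: use Lambda
Wrap preprocess_input in a Lambda layer. Keras then saves the wrapped function as a whole instead of trying to serialize the individual ops it produces, so the problematic slice never ends up in the config.
Because a Lambda layer wraps executable code, you have to make sure that code is safe.
That's why loading the model requires turning off safe_mode (load_model(..., safe_mode=False)).
Personally I find this the quickest to get working; when I have time I can gradually turn it into a proper solution.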
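Why does loading need safe_mode=False? Keras 2.x stores the function inside a Lambda layer as serialized Python bytecode (the Lambda docstring itself warns this is "fundamentally non-portable"). The sketch below reproduces that idea with the standard library only: it marshals a function's code object into a base64 string and rebuilds a callable from it. to_bgr_minus_mean, dump_function and load_function are illustrative names, not Keras APIs, and Keras's real helper differs in detail.

```python
import base64
import marshal
import types


def to_bgr_minus_mean(pixel):
    """Toy stand-in for vgg16.preprocess_input on a single (r, g, b) pixel."""
    r, g, b = pixel
    return (b - 103.939, g - 116.779, r - 123.68)


def dump_function(func):
    """Hypothetical helper: marshal the function's bytecode into a base64 string."""
    raw = marshal.dumps(func.__code__)
    return base64.b64encode(raw).decode("ascii")


def load_function(payload):
    """Hypothetical helper: rebuild a callable from the string.

    The rebuilt function runs whatever bytecode the payload contains,
    which is exactly the risk safe_mode guards against.
    """
    code = marshal.loads(base64.b64decode(payload))
    return types.FunctionType(code, globals())


payload = dump_function(to_bgr_minus_mean)  # roughly what ends up in the saved file
restored = load_function(payload)           # roughly what loading does
print(restored((255, 0, 0)) == to_bgr_minus_mean((255, 0, 0)))  # True
```

Loading means running bytecode taken straight from the file, and a tampered .keras file could contain anything, hence the safe_mode guard. The bytecode is also tied to the Python version that wrote it, which is one more reason to clean this up later.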
Taking "pretrained model + fine-tuning, error when the callback saves the model" as the example, here's the code:
import os
import shutil
import pathlib

import keras.callbacks
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Paths etc.
# ori_dir=pathlib.Path("/kaggle/input/dogsandcats/PetImages")
# new_dir=pathlib.Path("/kaggle/working/PetImages_sub")
# savefile="/kaggle/working/vgg16_fine_tune_convnet_from_scratch.keras"
# logdir="/kaggle/working/vgg16_fine_tune_tensorboard_mycatvsdog"

ori_dir=pathlib.Path(r"E:\my_AI_stuff\some_data\cats-vs-dogs\PetImages")
new_dir=pathlib.Path(r"E:\my_AI_stuff\mytest\PetImages_sub")
savefile="vgg16_fine_tune2_convnet_from_scratch.keras"
logdir="vgg16_fine_tune2_tensorboard_mycatvsdog"

# Data preparation
import cv2
def is_legal_image(image_path):
    # Check that the image file exists
    if not os.path.exists(image_path):
        print(f"Image file does not exist: {image_path}")
        return False

    # Check the file size
    file_size = os.path.getsize(image_path)
    if file_size == 0:
        print(f"Image file is empty: {image_path}")
        return False

    # Check the file extension
    valid_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
    file_extension = os.path.splitext(image_path)[1].lower()
    if file_extension not in valid_extensions:
        print(f"Invalid image format: {file_extension} of {image_path}")
        return False

    # Check that the file can actually be decoded
    try:
        img=cv2.imread(str(image_path))  # cv2.imread wants a str path; alternative: img=Image.open(image_path)
        if img is None:
            print(f"Cannot open image file: {image_path}")
            return False
        #imgarr=np.empty(img.shape)
        #imgarr[:]=img
    except Exception as e:
        print(f"Cannot open image file: {e} of {image_path}")
        return False

    # Check the image size
    height, width, _ = img.shape
    if height <= 0 or width <= 0:
        print(f"Abnormal image size: {image_path}, height:{height} width:{width}")
        return False

    # Check the number of color channels (len(img.shape) would count dimensions, not channels)
    num_channels = img.shape[2]
    if num_channels != 3:
        print(f"Invalid number of color channels: {num_channels} of {image_path}")
        return False

    # Check the pixel value range
    min_pixel_value = img.min()
    max_pixel_value = img.max()
    if min_pixel_value < 0 or max_pixel_value > 255:
        print(f"Abnormal pixel values: {image_path}, min_pixel_value:{min_pixel_value}, max_pixel_value:{max_pixel_value}")
        return False

    # The image looks fine
    img=None
    return True

def make_sub_dat(oridir,newdir,subdirname,bidx,eidx):
    for category in ("Cat","Dog"):
        dstdir=newdir/subdirname/category  # was: f"{newdir}\\{subdirname}\\{category}"
        os.makedirs(dstdir)
        filenames=[f"{i}.jpg" for i in range(bidx,eidx)]
        for f in filenames:
            pathfile=oridir/category/f
            if is_legal_image(pathfile):
                shutil.copy(src=oridir/category/f,dst=dstdir/f)
            else:
                pass

if not pathlib.Path(new_dir).exists():
    print("make a sub data set for train, validation and test")
    make_sub_dat(ori_dir,new_dir,"train",0,1000)
    make_sub_dat(ori_dir,new_dir,"validation",1000,1500)
    make_sub_dat(ori_dir,new_dir,"test",1500,2500)
    print("done")
else:
    print("the sub data set already exists, skip")
    #quit(2)

train_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"train",image_size=(180,180),batch_size=32)
validation_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"validation",image_size=(180,180),batch_size=32)
test_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"test",image_size=(180,180),batch_size=32)

if pathlib.Path(savefile).exists():
    #model=tf.keras.models.load_model(savefile,custom_objects={"my_preprocess_input":my_preprocess_input})
    model = tf.keras.models.load_model(savefile,safe_mode=False)  # the Lambda layer holds Python code, so safe_mode must be off
    testloss, testacc = model.evaluate(test_dataset)
    print(f"testloss:{testloss},testacc:{testacc}")
    quit(2)

# Pretrained model
convbase=tf.keras.applications.vgg16.VGG16(include_top=False,
                                          weights="imagenet",
                                          input_shape=(180,180,3))
# Unfreeze only the last 4 layers of the conv base (block5_conv1..3, block5_pool); freeze all the others
#convbase.summary()
convbase.trainable=True
for lay in convbase.layers[:-4]:
    lay.trainable=False
# for lay in convbase.layers:
#     print(f"{lay.name}: {lay.trainable}")
convbase.summary()

# Model

data_augmentation=tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.2),
])

inputs=tf.keras.Input(shape=(180,180,3))
x=data_augmentation(inputs)
#x=tf.keras.applications.vgg16.preprocess_input(x)  # calling it directly here is what breaks saving
#x=my_preprocess_input(x)
x=tf.keras.layers.Lambda(tf.keras.applications.vgg16.preprocess_input)(x)
x=convbase(x)
x=tf.keras.layers.Flatten()(x)
x=tf.keras.layers.Dense(256)(x)
x=tf.keras.layers.Dropout(0.5)(x)
outputs=tf.keras.layers.Dense(1,activation="sigmoid")(x)
model=tf.keras.Model(inputs,outputs)
model.summary()

model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-5),
              loss=tf.keras.losses.binary_crossentropy,
              metrics=[tf.keras.metrics.binary_accuracy])
#model.save(savefile)  # without the Lambda wrapper this raised: TypeError: Cannot serialize object Ellipsis of type <class 'ellipsis'>. To be serializable, a class must implement the `get_config()` method.

mycallbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=savefile,
                                                save_best_only=True,
                                                monitor="val_loss"),
             tf.keras.callbacks.TensorBoard(log_dir=logdir)]

history=model.fit(train_dataset,epochs=1,
                  validation_data=validation_dataset,
                  callbacks=mycallbacks)

#print(history.history.keys())
accuracy = history.history["binary_accuracy"]
val_accuracy = history.history["val_binary_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(accuracy) + 1)
plt.plot(epochs, accuracy, "bo", label="Training accuracy")
plt.plot(epochs, val_accuracy, "b", label="Validation accuracy")
plt.title("Training and validation accuracy")
plt.legend()
plt.figure()
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()
plt.show()
===========================
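Method 3: provide what serialization needs
If you want it saved, then do what the error asks: wrap preprocess_input in a custom Layer subclass that implements get_config(), and register it with register_keras_serializable so that load_model can find it.
Taking "pretrained model + fine-tuning, error when the callback saves the model" as the example, here's the code: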
import os
import shutil
import pathlib

import keras.callbacks
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Paths etc.
# ori_dir=pathlib.Path("/kaggle/input/dogsandcats/PetImages")
# new_dir=pathlib.Path("/kaggle/working/PetImages_sub")
# savefile="/kaggle/working/vgg16_fine_tune_convnet_from_scratch.keras"
# logdir="/kaggle/working/vgg16_fine_tune_tensorboard_mycatvsdog"

ori_dir=pathlib.Path(r"E:\my_AI_stuff\some_data\cats-vs-dogs\PetImages")
new_dir=pathlib.Path(r"E:\my_AI_stuff\mytest\PetImages_sub")
savefile="vgg16_fine_tune4_convnet_from_scratch.keras"
logdir="vgg16_fine_tune4_tensorboard_mycatvsdog"

# Data preparation
import cv2
def is_legal_image(image_path):
    # Check that the image file exists
    if not os.path.exists(image_path):
        print(f"Image file does not exist: {image_path}")
        return False

    # Check the file size
    file_size = os.path.getsize(image_path)
    if file_size == 0:
        print(f"Image file is empty: {image_path}")
        return False

    # Check the file extension
    valid_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
    file_extension = os.path.splitext(image_path)[1].lower()
    if file_extension not in valid_extensions:
        print(f"Invalid image format: {file_extension} of {image_path}")
        return False

    # Check that the file can actually be decoded
    try:
        img=cv2.imread(str(image_path))  # cv2.imread wants a str path; alternative: img=Image.open(image_path)
        if img is None:
            print(f"Cannot open image file: {image_path}")
            return False
        #imgarr=np.empty(img.shape)
        #imgarr[:]=img
    except Exception as e:
        print(f"Cannot open image file: {e} of {image_path}")
        return False

    # Check the image size
    height, width, _ = img.shape
    if height <= 0 or width <= 0:
        print(f"Abnormal image size: {image_path}, height:{height} width:{width}")
        return False

    # Check the number of color channels (len(img.shape) would count dimensions, not channels)
    num_channels = img.shape[2]
    if num_channels != 3:
        print(f"Invalid number of color channels: {num_channels} of {image_path}")
        return False

    # Check the pixel value range
    min_pixel_value = img.min()
    max_pixel_value = img.max()
    if min_pixel_value < 0 or max_pixel_value > 255:
        print(f"Abnormal pixel values: {image_path}, min_pixel_value:{min_pixel_value}, max_pixel_value:{max_pixel_value}")
        return False

    # The image looks fine
    img=None
    return True

def make_sub_dat(oridir,newdir,subdirname,bidx,eidx):
    for category in ("Cat","Dog"):
        dstdir=newdir/subdirname/category  # was: f"{newdir}\\{subdirname}\\{category}"
        os.makedirs(dstdir)
        filenames=[f"{i}.jpg" for i in range(bidx,eidx)]
        for f in filenames:
            pathfile=oridir/category/f
            if is_legal_image(pathfile):
                shutil.copy(src=oridir/category/f,dst=dstdir/f)
            else:
                pass

if not pathlib.Path(new_dir).exists():
    print("make a sub data set for train, validation and test")
    make_sub_dat(ori_dir,new_dir,"train",0,1000)
    make_sub_dat(ori_dir,new_dir,"validation",1000,1500)
    make_sub_dat(ori_dir,new_dir,"test",1500,2500)
    print("done")
else:
    print("the sub data set already exists, skip")
    #quit(2)

train_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"train",image_size=(180,180),batch_size=32)
validation_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"validation",image_size=(180,180),batch_size=32)
test_dataset=tf.keras.preprocessing.image_dataset_from_directory(new_dir/"test",image_size=(180,180),batch_size=32)

# Custom layer wrapping preprocess_input. The decorator registers it as "MyLayers>MyPreprocessInput";
# it must be defined before load_model() runs.
@keras.saving.register_keras_serializable(package="MyLayers")
class MyPreprocessInput(tf.keras.layers.Layer):
    def __init__(self,**kwargs):
        super().__init__(**kwargs)

    def call(self,data_inputs):
        return tf.keras.applications.vgg16.preprocess_input(data_inputs)

    def get_config(self):
        # No extra constructor arguments, so the base config (name, trainable, dtype) is enough
        config=super().get_config()
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


if pathlib.Path(savefile).exists():
    model=tf.keras.models.load_model(savefile)
    testloss, testacc = model.evaluate(test_dataset)
    print(f"testloss:{testloss},testacc:{testacc}")
    quit(2)

# Pretrained model
convbase=tf.keras.applications.vgg16.VGG16(include_top=False,
                                          weights="imagenet",
                                          input_shape=(180,180,3))
# Unfreeze only the last 4 layers of the conv base (block5_conv1..3, block5_pool); freeze all the others
#convbase.summary()
convbase.trainable=True
for lay in convbase.layers[:-4]:
    lay.trainable=False
# for lay in convbase.layers:
#     print(f"{lay.name}: {lay.trainable}")
convbase.summary()

# Model

data_augmentation=tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.2),
])

inputs=tf.keras.Input(shape=(180,180,3))
x=data_augmentation(inputs)
x=MyPreprocessInput()(x)
x=convbase(x)
x=tf.keras.layers.Flatten()(x)
x=tf.keras.layers.Dense(256)(x)
x=tf.keras.layers.Dropout(0.5)(x)
outputs=tf.keras.layers.Dense(1,activation="sigmoid")(x)
model=tf.keras.Model(inputs,outputs)
model.summary()

model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-5),
              loss=tf.keras.losses.binary_crossentropy,
              metrics=[tf.keras.metrics.binary_accuracy])
#model.save(savefile)  # without the custom layer this raised: TypeError: Cannot serialize object Ellipsis of type <class 'ellipsis'>. To be serializable, a class must implement the `get_config()` method.

mycallbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=savefile,
                                                save_best_only=True,
                                                monitor="val_loss"),
             tf.keras.callbacks.TensorBoard(log_dir=logdir)]

history=model.fit(train_dataset,epochs=1,  # 50 for a real run
                  validation_data=validation_dataset,
                  callbacks=mycallbacks)

#print(history.history.keys())
accuracy = history.history["binary_accuracy"]
val_accuracy = history.history["val_binary_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(accuracy) + 1)
plt.plot(epochs, accuracy, "bo", label="Training accuracy")
plt.plot(epochs, val_accuracy, "b", label="Validation accuracy")
plt.title("Training and validation accuracy")
plt.legend()
plt.figure()
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()
plt.show()
===================
That's it for the standard approaches. I don't know any hacky ones.
===================
You can work all of this out yourself from the official docs: https://keras.io/guides/serialization_and_saving/#custom-objects
===================
All right, the obstacle is cleared, and I've demonstrated each of the methods.