Keras 单机多卡训练模型
注意:在 tf.distribute.MirroredStrategy(单机多卡)模式下不能使用 fit_generator() 训练,请改用 model.fit()。
""" GPU test """ import os import sys os.system('pip install -i https://pypi.tuna.tsinghua.edu.cn/simple keras==2.3.1') from tensorflow.keras import Sequential from tensorflow.keras.models import Model from tensorflow.keras.layers import Input,Dense from tensorflow.keras import layers from tensorflow.keras.callbacks import ModelCheckpoint,EarlyStopping import tensorflow as tf from tensorflow import keras import numpy as np import pickle import time checkpoint_save_dir = "/models/embedding_recall/ckpt" best_model_name = "best_model.hdf5" if not os.path.exists(checkpoint_save_dir): os.makedirs(checkpoint_save_dir) with open(r"/models/embedding_recall/resources/minist.pkl","rb") as fr: data = pickle.load(fr) def make_or_restore_model(): # Either restore the latest model, or create a fresh one # if there is no checkpoint available. checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)] if checkpoints: latest_checkpoint = max(checkpoints, key=os.path.getctime) print("Restoring from", latest_checkpoint) return keras.models.load_model(latest_checkpoint) print("Creating a new model") return get_compiled_model() def get_compiled_model(): inputs = Input(shape=(784,)) inputs.shape inputs.dtype dense = Dense(64, activation="relu") x = dense(inputs) x = Dense(64, activation="relu")(x) outputs = Dense(10)(x) model = Model(inputs=inputs, outputs=outputs, name="my_model") model.compile( loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer = keras.optimizers.RMSprop(), metrics = ["accuracy"],) return model def make_or_restore_model(checkpoint_save_dir, model_name): # Either restore the latest model, or create a fresh one # if there is no checkpoint available. 
if checkpoint_save_dir: latest_checkpoint = os.path.join(checkpoint_save_dir, model_name) print("Restoring from", latest_checkpoint) return keras.models.load_model(latest_checkpoint) else: return None def run_training(epochs): strategy = tf.distribute.MirroredStrategy() print("Number of devices:{}".format(strategy.num_replicas_in_sync)) with strategy.scope(): model = get_compiled_model() (x_train, y_train),(x_test, y_test) = data[0],data[1] x_train = x_train.reshape(60000, 784).astype("float32")/255 x_test = x_test.reshape(10000, 784).astype("float32")/255 early_stop = EarlyStopping(monitor='loss', patience=3, verbose=1) checkpoint = ModelCheckpoint(os.path.join(checkpoint_save_dir, best_model_name), monitor='loss', verbose=1, save_best_only=True, mode='min') callbacks_list = [checkpoint, early_stop] t1 = time.time() history = model.fit(x_train, y_train, batch_size=100, epochs=epochs, callbacks=callbacks_list) t2 = time.time() # test_scores = model.evaluate(x_test, y_test, batch_size=100,verbose=2) # print("test loss:{}".format(test_scores[0])) # print("test acc:{}".format(test_scores[1])) # print("total spent:{}".format(t2-t1)) def continue_training(epochs): strategy = tf.distribute.MirroredStrategy() print("Number of devices:{}".format(strategy.num_replicas_in_sync)) # with strategy.scope(): model = make_or_restore_model(checkpoint_save_dir, best_model_name) (x_train, y_train),(x_test, y_test) = data[0],data[1] x_train = x_train.reshape(60000, 784).astype("float32")/255 x_test = x_test.reshape(10000, 784).astype("float32")/255 early_stop = EarlyStopping(monitor='loss', patience=3, verbose=1) checkpoint = ModelCheckpoint(os.path.join(checkpoint_save_dir, best_model_name), monitor='loss', verbose=1, save_best_only=True, mode='min') callbacks_list = [checkpoint, early_stop] t1 = time.time() history = model.fit(x_train, y_train, batch_size=100, epochs=epochs, callbacks=callbacks_list) t2 = time.time() run_training(epochs=5) # continue_training(epochs=100)
时刻记着自己要成为什么样的人!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)