TensorFlow 2.0 教程22:DCGAN

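  1.数据导入和预处理（以下代码假设已导入 tensorflow as tf、tensorflow.keras 的 layers、matplotlib.pyplot as plt、os、time、glob、imageio、PIL 以及 IPython 的 display）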

  # Only the MNIST training images are used; labels and the test split are ignored.
  (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

  # Add a trailing channel axis and cast to float32, then rescale pixels from
  # [0, 255] to [-1, 1] -- the same range as the generator's tanh output.
  train_images = train_images.reshape(-1, 28, 28, 1).astype('float32')
  train_images = (train_images - 127.5) / 127.5

  BATCH_SIZE = 256
  # Shuffle buffer size; presumably chosen to cover the whole training set
  # (a full shuffle) -- confirm if the dataset changes.
  BUFFER_SIZE = 60000

  # Shuffle the individual images, then group them into batches.
  train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
  train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

  2.构建模型

  构建生成器

  def make_generator_model(latent_dim=100):
      """Build the DCGAN generator: latent vector -> 28x28x1 image in [-1, 1].

      Args:
          latent_dim: Length of the input noise vector. Defaults to 100, the
              size this script samples noise with, so existing calls to
              ``make_generator_model()`` are unchanged.

      Returns:
          An uncompiled ``tf.keras.Sequential`` model.
      """
      model = tf.keras.Sequential()

      # Project the noise vector and reshape it into a 7x7x256 feature map.
      # No bias: BatchNormalization follows and its shift makes a bias redundant.
      model.add(layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(latent_dim,)))
      model.add(layers.BatchNormalization())
      model.add(layers.LeakyReLU())
      model.add(layers.Reshape((7, 7, 256)))
      # Architecture sanity checks; None is the batch dimension.
      assert model.output_shape == (None, 7, 7, 256)

      # Stride 1: keep the 7x7 resolution while reducing channels to 128.
      model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
      assert model.output_shape == (None, 7, 7, 128)
      model.add(layers.BatchNormalization())
      model.add(layers.LeakyReLU())

      # Stride 2: upsample 7x7 -> 14x14.
      model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
      assert model.output_shape == (None, 14, 14, 64)
      model.add(layers.BatchNormalization())
      model.add(layers.LeakyReLU())

      # Stride 2: upsample 14x14 -> 28x28 with one channel; tanh bounds the
      # output to [-1, 1], matching how the training images are scaled.
      model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
      assert model.output_shape == (None, 28, 28, 1)

      return model

  生成器生成图片

  # Smoke test: push one random latent vector through the untrained generator.
  generator = make_generator_model()
  noise = tf.random.normal(shape=(1, 100))
  generated_image = generator(noise, training=False)

  # Show the single channel of the first (only) sample as a grayscale image.
  plt.imshow(generated_image[0, ..., 0], cmap='gray')

  构造判别器

  def make_discriminator_model():
      """Build the DCGAN discriminator: 28x28x1 image -> one real/fake score.

      The final Dense layer has no activation, so the output is a raw logit
      (the loss below is created with ``from_logits=True``).

      Returns:
          An uncompiled ``tf.keras.Sequential`` model.
      """
      return tf.keras.Sequential([
          # Stride 2: 28x28 -> 14x14.
          layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                        input_shape=[28, 28, 1]),
          layers.LeakyReLU(),
          layers.Dropout(0.3),
          # Stride 2: 14x14 -> 7x7.
          layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
          layers.LeakyReLU(),
          layers.Dropout(0.3),
          # Collapse the feature map into a single logit per image.
          layers.Flatten(),
          layers.Dense(1),
      ])

  判别器判别

  # Score the untrained generator's sample; the result is a single raw logit.
  discriminator = make_discriminator_model()
  decision = discriminator(generated_image)
  print(decision)

  输出（模型未训练且输入为随机噪声，数值每次运行都会不同）：tf.Tensor([[-0.00016926]], shape=(1, 1), dtype=float32)

  3.定义损失函数

  # Shared binary cross-entropy. from_logits=True because the discriminator's
  # final Dense layer applies no sigmoid.
  cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


  def discriminator_loss(real_output, fake_output):
      """Discriminator loss: real logits should score 1, fake logits 0.

      Args:
          real_output: Discriminator logits for a batch of real images.
          fake_output: Discriminator logits for a batch of generated images.

      Returns:
          The sum of the cross-entropy on real and on fake samples.
      """
      loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
      loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
      return loss_on_real + loss_on_fake


  def generator_loss(fake_output):
      """Generator loss: reward fakes that the discriminator scores as real (1).

      Args:
          fake_output: Discriminator logits for a batch of generated images.
      """
      wanted_labels = tf.ones_like(fake_output)
      return cross_entropy(wanted_labels, fake_output)

  # One Adam optimizer per network (learning rate 1e-4), so each keeps its own state.
  generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
  discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

  checkpoint保存

  # Track both models and both optimizers so training can be resumed or restored.
  checkpoint_dir = './training_checkpoints'
  checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
  checkpoint = tf.train.Checkpoint(
      generator_optimizer=generator_optimizer,
      discriminator_optimizer=discriminator_optimizer,
      generator=generator,
      discriminator=discriminator,
  )

  4.训练函数

  # Training length and sampling configuration.
  num_examples_to_generate = 16
  noise_dim = 100
  EPOCHS = 50

  # A fixed batch of latent vectors reused after every epoch, so the saved
  # snapshots (and the animated GIF built from them) show the same samples
  # evolving, which makes progress easy to follow.
  seed = tf.random.normal([num_examples_to_generate, noise_dim])

  训练迭代函数

  # `tf.function` traces this step into a TensorFlow graph ("compiles" it).
  @tf.function
  def train_step(images):
      """Run one simultaneous generator/discriminator update on a batch.

      Args:
          images: A batch of real training images.

      Returns:
          ``(gen_loss, disc_loss)`` for this step; callers may ignore it.
      """
      # One noise vector per real image. Sizing from the actual batch instead
      # of the global BATCH_SIZE avoids generating unused fakes for a smaller
      # final batch (batch() keeps the remainder by default) and lets this
      # step work with any batch size.
      batch_size = tf.shape(images)[0]
      noise = tf.random.normal([batch_size, noise_dim])

      # Two tapes: each network's gradients come from its own loss.
      with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
          generated_images = generator(noise, training=True)

          real_output = discriminator(images, training=True)
          fake_output = discriminator(generated_images, training=True)

          gen_loss = generator_loss(fake_output)
          disc_loss = discriminator_loss(real_output, fake_output)

      gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
      gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

      generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
      discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

      return gen_loss, disc_loss

  训练函数

  def train(dataset, epochs, save_every=15):
      """Train the GAN for ``epochs`` passes over ``dataset``.

      After every epoch the fixed ``seed`` batch is rendered and saved as an
      image, and the epoch's wall-clock time is printed.

      Args:
          dataset: Iterable of image batches, each passed to ``train_step``.
          epochs: Number of passes over ``dataset``.
          save_every: Save a checkpoint every this many epochs (positive int,
              default 15). A checkpoint is also saved after the final epoch, so
              ``tf.train.latest_checkpoint`` returns the fully trained weights
              even when ``epochs`` is not a multiple of ``save_every``.
      """
      for epoch in range(epochs):
          start = time.time()

          for image_batch in dataset:
              train_step(image_batch)

          # Produce images for the GIF as we go.
          display.clear_output(wait=True)
          generate_and_save_images(generator, epoch + 1, seed)

          # Periodic checkpoints plus one after the last epoch; without the
          # latter, e.g. 50 epochs with save_every=15 would leave epoch 45 as
          # the newest checkpoint.
          epoch_no = epoch + 1
          if epoch_no % save_every == 0 or epoch_no == epochs:
              checkpoint.save(file_prefix=checkpoint_prefix)

          print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

      # Generate after the final epoch.
      display.clear_output(wait=True)
      generate_and_save_images(generator, epochs, seed)

  生成和保存图像

  def generate_and_save_images(model, epoch, test_input):
      """Render ``model(test_input)`` as an image grid, save it, and show it.

      The figure is written to ``image_at_epoch_{epoch:04d}.png``. The zero
      padding keeps the files in epoch order when sorted by name.

      Args:
          model: Generator to sample from.
          epoch: Epoch number used in the output filename.
          test_input: Batch of latent vectors; one image is drawn per row.
      """
      import math

      # training=False runs all layers in inference mode (e.g. BatchNormalization).
      predictions = model(test_input, training=False)

      # Near-square grid big enough for every sample (16 samples -> 4x4).
      # The original fixed 4x4 grid failed for more than 16 samples.
      count = predictions.shape[0]
      cols = max(1, math.ceil(math.sqrt(count)))
      rows = max(1, math.ceil(count / cols))

      fig = plt.figure(figsize=(4, 4))
      for i in range(count):
          plt.subplot(rows, cols, i + 1)
          # Map the generator's [-1, 1] output back to pixel range [0, 255].
          plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
          plt.axis('off')

      plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
      plt.show()
      # Release the figure; otherwise one figure per call accumulates in
      # pyplot's registry whenever show() does not close it (e.g. Agg backend).
      plt.close(fig)

  5.模型训练

  %%time

  train(train_dataset, EPOCHS)

  [图片：训练结束时生成的样本图像（原文图片缺失）]

  CPU times: user 11h 8min 3s, sys: 9min 27s, total: 11h 17min 31s

  Wall time: 3h 13min 51s

  # 恢复最新的 checkpoint，并显示最后一个 epoch 生成的图片

  # Reload the newest checkpoint found in checkpoint_dir.
  checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))


  def display_image(epoch_no):
      """Open the snapshot saved for epoch ``epoch_no`` as a PIL image."""
      snapshot_path = 'image_at_epoch_{:04d}.png'.format(epoch_no)
      return PIL.Image.open(snapshot_path)


  display_image(EPOCHS)

  ## 6.训练过程的动图

  # Stitch the saved per-epoch snapshots into an animated GIF.
  with imageio.get_writer('dcgan.gif', mode='I') as writer:
      # Filenames carry zero-padded epoch numbers, so name order == epoch order.
      filenames = sorted(glob.glob('image*.png'))
      last = -1
      for i, filename in enumerate(filenames):
          # Keep snapshot i only when round(2*sqrt(i)) increases, so frames
          # become sparser as training progresses.
          frame = 2 * (i ** 0.5)
          if round(frame) <= round(last):
              continue
          last = frame
          writer.append_data(imageio.imread(filename))
      # Always end on the newest snapshot (it may have been skipped above, or
      # is repeated to hold the final frame). With no snapshots there is
      # nothing to append, and the loop variable would be unbound.
      if filenames:
          writer.append_data(imageio.imread(filenames[-1]))

  # A hack to display the GIF inside this notebook: give it a .png name.
  # os.replace overwrites a file left by a previous run (os.rename fails on Windows).
  os.replace('dcgan.gif', 'dcgan.gif.png')
  display.Image(filename="dcgan.gif.png")

posted @ 2019-09-07 14:49  网管布吉岛  阅读(557)  评论(0)  编辑  收藏  举报