tfgan折腾笔记(二):核心函数详述——gan_model族

定义model的函数有:

1.gan_model

函数原型:

def gan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True)

参数:

generator_fn:预先定义好的生成器网络的函数名称。预先定义好的生成器函数的输入参数应该是接下来要说明的第四个参数generator_inputs,生成网络的返回值是网络的输出(因为是GAN,所以生成器的输出一般是一幅机器生成的图像)。

discriminator_fn:预先定义好的判别器网络的函数名称。预先定义好的判别器函数的输入参数有两个:第一个是“真实数据(图像)”/“机器生成的图像(generator_fn的返回值)”;第二个是生成器的输入,即此函数的第四个参数(在普通的gan当中,判别器只需要第一个参数。即使不需要第二个参数,也必须显式地定义出第二个参数,只不过定义了之后在判别器函数中可以不使用)。判别器的返回值必须在负无穷到正无穷之间([-inf, +inf])。

real_data:真实图像。一般传入真实图像batch化后的引用。

generator_inputs:生成器的输入。对于vallina gan,是tensor类型的噪声。除此之外,如果是c-gan,还可以传入一个list或tuple作为参数(在下方的“其他说明“里详细说明c-gan(conditional-gan)的情况)。

generator_scope:传入这个参数可以定义生成器内参数的变量命名空间(variable_scope)。默认为"Generator"。

discriminator_scope:传入这个参数可以定义判别器内参数的变量命名空间(variable_scope)。默认为"Discriminator"。

check_shapes:如果为真,将检查生成器生成的数据与真实数据是否有相同的shape。如果为假,则跳过检查。

返回值:

返回一个“GANModel 命名管道”。实际上就是一个由生成器函数、判别器函数、生成的数据、变量空间等东西组成的一个List。这个返回值不需要我们写程序的时候用,就不过多解释了(具体用法见本系列上一篇文档:传送门)。

函数内部实现:

generator_fn和discriminator_fn在gan_model函数里这样调用:

# 由机器生成数据
generated_data = generator_fn(generator_inputs)

# 判别器判断机器生成图片的真实性
discriminator_gen_outputs = discriminator_fn(generated_data, generator_inputs)

# 判别器判断真实图片的真实性
discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

 

其他说明:

  • gan_model支持conditional-gan。若需要训练c-gan,要通过generator_inputs额外传入标签信息。如:generator_inputs=(noise, one_hot_label)。同时,判别器网络与生成器网络应该按照c-gan论文中的模型重新定义。
  • real_data一般为一个next_batch。如:next_batch = tf.compat.v1.data.make_one_shot_iterator(image_ds).get_next()

2.infogan_model

函数原型:

def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator')

参数:

generator_fn:预先定义好的生成器网络的函数名称。预先定义好的生成器函数的输入参数应该是接下来要说明的unstructrued_generator_inputs与structured_generator_inputs共同组成的列表,列表中的每一项是一个Tensor,生成网络的返回值是生成器的输出。

discriminator_fn:预先定义好的判别器网络的函数名称。预先定义好的判别器函数的输入参数应该有两个:第一个是“真实数据(图像)”/“机器生成的图像(generator_fn的返回值)”;第二个是生成器的输入,即(unstructrued_generator_inputs与structured_generator_inputs共同组成的列表)。预先定义好的判别器函数的输出应是一个二维Tuple。Tuple的第一维是生成器网络输出层的logits,范围在[-inf, +inf]。Tuple的第二维是分布的列表:此分布的第i个列表元素代表的是第i个structure noise 的分布。

real_data:真实图像。一般传入真实图像batch化后的引用。

unstructured_generator_inputs:Tensor的列表。表示非结构化的noise或条件。

structured_generator_inputs:Tensor的列表。这些Tensor必须与识别器具有较高的相互信息。

generator_scope:传入这个参数可以定义生成器内参数的变量命名空间(variable_scope)。默认为"Generator"。

discriminator_scope:传入这个参数可以定义判别器内参数的变量命名空间(variable_scope)。默认为"Discriminator"。

返回值:

返回一个“InfoGANModel 命名管道”。同“GANModel 命名管道”一样,我们无需关心管道中的具体内容。

函数内部实现:

生成器的输入这样定义:

generator_inputs = (unstructured_generator_inputs + structured_generator_inputs)

 

生成器和判别器这样调用:

# 由机器生成数据
generated_data = generator_fn(generator_inputs)

# 判别器判断机器生成图片的真实性
dis_gen_outputs, predicted_distributions = discriminator_fn(generated_data, generator_inputs)

# 判别器判断真实图片的真实性
dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)

 

其他说明:

  • 关于生成器和判别器网络模型的搭建,请参照Info-GAN的论文。
  • real_data一般为一个next_batch。如:next_batch = tf.compat.v1.data.make_one_shot_iterator(image_ds).get_next()

3.acgan_model:

函数原型:

def acgan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    one_hot_labels,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True)

 

参数:

与gan_model中的参数基本一致,除了:

discriminator_fn:预定义的判别器函数应当返回一个二维Tuple。第一维是网络输出层的real或者fake的logits;第二维是分类器的logits。他们两个的范围都应该是[-inf, +inf]。

one_hot_labels:对应于一个batch图像的one_hot_label。

返回值:

返回“AcGANModel 命名管道”。同“GANModel 命名管道”一样,我们无需关心管道中的具体内容。

函数内部实现:

生成器和判别器这样调用:

# 由机器生成数据
generated_data = generator_fn(generator_inputs)

# 判别器判断机器生成图片的真实性
(discriminator_gen_outputs, discriminator_gen_classification_logits) = _validate_acgan_discriminator_outputs(discriminator_fn(generated_data, generator_inputs))

# 判别器判断真实图片的真实性
(discriminator_real_outputs, discriminator_real_classification_logits) = _validate_acgan_discriminator_outputs(discriminator_fn(real_data, generator_inputs))

 

其他说明:

  • one_hot_labels在此函数内部没有被使用,而是直接通过命名管道(返回值)传递给gan_loss函数(下一篇详细说明)。
  • one_hot_labels与real_data均为batch。

4.cyclegan_model:

函数原型:

def cyclegan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # data X and Y.
    data_x,
    data_y,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    model_x2y_scope='ModelX2Y',
    model_y2x_scope='ModelY2X',
    # Options.
    check_shapes=True)

 

参数:

generator_fn:预先定义好的生成器函数。此生成器的输入有一个参数,与gan_model的generator_fn一样。返回值为生成器网络的输出。

discriminator_fn:预先定义好的判别器函数。与gan_model的discriminator_fn定义一样。

data_x:x域的真实数据。

data_y:y域的真实数据。

generator_scope:与gan_model的generator_scope意义一样。

discriminator_scope:与gan_model的discriminator_scope意义一样。

model_x2y_scope:x->y转换过程的variable_scope。

model_y2x_scope:y->x转换过程的variable_scope。

check_shapes:如果为真,将检查生成器生成的数据与真实数据是否有相同的shape。如果为假,则跳过检查。

返回值:

返回“CycleGANModel 命名空间”。

函数内部实现:

此函数实际上调用了gan_model函数,如下所示:

# Create models.
  def _define_partial_model(input_data, output_data):    # 内部函数定义
    return gan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=output_data,
        generator_inputs=input_data,
        generator_scope=generator_scope,
        discriminator_scope=discriminator_scope,
        check_shapes=check_shapes)

  with tf.compat.v1.variable_scope(model_x2y_scope):
    model_x2y = _define_partial_model(data_x, data_y)
  with tf.compat.v1.variable_scope(model_y2x_scope):
    model_y2x = _define_partial_model(data_y, data_x)

  with tf.compat.v1.variable_scope(model_y2x.generator_scope, reuse=True):
    reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data)
  with tf.compat.v1.variable_scope(model_x2y.generator_scope, reuse=True):
    reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data)

  return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x,
                                   reconstructed_y)

 

其他说明:

5.stargan_model

函数原型:

def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator')

 

参数:

generator_fn:预先定义好的函数的函数名称。函数的输入有两个,应分别为:input、target,返回值是根据inputs和targets由机器生成的图像。inputs的形状应该是(batch, height, width, channel),targets的形状是(batch, num_domain)。返回值有和inputs相同的形状。

discriminator_fn:预先定义好的函数的函数名称。此函数的输入有两个,分别为input和num_domain。返回值是一个Tuple:(`source_prediction`, `domain_prediction`)。`source_prediction`表示预测的图像(真实或生成的)真实度,“ domain_prediction”代表判别器对域分类的预测(真实度)。 `source_prediction`的形状是(batch), `domain_prediction`具有形状(batch,num_domains)。

input_data:Tensor或Tensor组成的列表。代表真实输入的图片。形状是(batch, height, width, channel)。

input_data_domain_label:Tensor或Tensor组成的列表。形状为(batch, num_domains)。表示真实数据相对应的代表域的标签。

generator_scope:与gan_model的此参数意义相同。

discriminator_scope:与gan_model的此参数意义相同。

返回值:

返回“StarGANModel 命名空间”。

函数内部实现:

 函数内部重要代码如下:

  # Transform input_data to random target domains.
  with tf.compat.v1.variable_scope(generator_scope) as generator_scope:
    generated_data_domain_target = generate_stargan_random_domain_target(
        batch_size, num_domains)
    generated_data = generator_fn(input_data, generated_data_domain_target)

  # Transform generated_data back to the original input_data domain.
  with tf.compat.v1.variable_scope(generator_scope, reuse=True):
    reconstructed_data = generator_fn(generated_data, input_data_domain_label)

  # Predict source and domain for the generated_data using the discriminator.
  with tf.compat.v1.variable_scope(discriminator_scope) as discriminator_scope:
    disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn(
        generated_data, num_domains)

  # Predict source and domain for the input_data using the discriminator.
  with tf.compat.v1.variable_scope(discriminator_scope, reuse=True):
    disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn(
        input_data, num_domains)

 

其他说明:

 无

posted @ 2020-03-02 17:15  WongWai95  阅读(683)  评论(0编辑  收藏  举报