flags 是 tensorflow 定义参数的比较正规的方式,没什么特别的,直接上代码
flags = tf.app.flags flags.DEFINE_integer("epoch", 1000, "Epoch to train [25]") flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]") flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]") flags.DEFINE_integer("train_size", 256, "The size of train images [np.inf]") flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]") flags.DEFINE_string("dataset", "mnist", "The name of dataset [celebA, mnist, lsun]") flags.DEFINE_boolean("train", True, "True for training, False for testing [False]") FLAGS = flags.FLAGS def main(_): print(FLAGS.epoch) ### 1000 if __name__ == '__main__': tf.app.run()
就是这么简单
需要注意的是:
1. tf.app.flags 定义的参数会自动传送给 tf.app.run
2. 在执行 tf.app.run 时必须有个 main 函数,且 main 函数必须有一个参数
3. 可用命令行的方式重置参数
4. 底层实现是 argparse,用法雷同
参考资料:
https://www.360kuai.com/pc/9a31be7ed9823b890?cota=4&kuai_so=1&tj_url=so_rec&sign=360_57c3bbd1&refer_scene=so_1
https://blog.csdn.net/u014084019/article/details/78586390