《DenseFuse: A Fusion Approach to Infrared and Visible Images》代码分析
这个代码写的好。
模块主要用到了:os(主要作用文件目录和路径)、scipy(图片读取、保存、缩放,需要依赖PIL)、numpy、tensorflow、time(计算代码块耗时)。
首先判断是训练还是测试(即生产),如果是测试的话是测试video还是图片,测试的图片是否是彩色的;如果是训练的话是否要打印输出详细的信息(即debug)。debug可以作为切分代码块的标记使用。
训练模型在代码上可以分为训练前、训练中、训练后。
训练前:数据的获取、计算图的构建(包括网络结构、损失函数)、计算图参数的初始化、模型保存对象
- 计算图网络结构的重点是卷积核的定义(也就是权重的定义【kernel+bias】)kernel = tf.Variable(tf.truncated_normal(shape, stddev=WEIGHT_INIT_STDDEV), name='kernel'),定义好卷积核之后就可以开始卷积操作了out = tf.nn.conv2d(x_padded, kernel, strides=[1, 1, 1, 1], padding='VALID')。
- 计算图中损失函数的定义是难点 train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)
训练中:主要是对EPOCHS-N_BATCHES的迭代,还有验证
- 先根据损失函数调整网络参数--->sess.run(train_op, feed_dict={original: original_batch})。
- 然后再计算、存储和打印各部分的损失-->_ssim_loss, _loss, _p_loss = sess.run([ssim_loss, loss, pixel_loss], feed_dict={original: original_batch})
- 因为每迭代一次就会更新一次权重,事实上每更新一次权重就应该利用更新好的权重状态对验证数据集进行一次验证