PC 14+15+16

Long time no see!
It turns out that even in senior year the course load is still heavy (PE x2, labs x2, major courses x2, English x1), so I don't get much tinkering time each day, which is why I haven't updated in a while. Also, over the past few days I've been picking up bits of Linux knowledge, like shell, vim, and tmux. My skills are still shallow and I didn't really know where to start, but at least I can now read other people's shell scripts, which counts as progress (smile).

A problem I ran into recently involves the use of tf.estimator. In the tfnet code,
tfnet_est.predict
raises an error if the input dataset contains only low_wav; it insists on being fed (high_wav, low_wav) pairs. This is odd, because logically predict should only need low_wav.
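A plausible cause (an assumption on my part, since the model function isn't shown here): if the model function unpacks `features` as a two-element pair, then an input_fn that yields a single tensor fails at the unpacking step, not in predict itself. A plain-Python sketch of that failure mode, with a hypothetical `toy_model_fn`:

```python
import numpy as np

def toy_model_fn(features):
    # Hypothetical model body that assumes features is a (low, high) pair.
    low_wav, high_wav = features
    return low_wav.shape, high_wav.shape

# A (low, high) pair unpacks cleanly:
pair = (np.zeros((4, 8192, 1), np.float32),
        np.zeros((4, 8192, 1), np.float32))
print(toy_model_fn(pair))  # ((4, 8192, 1), (4, 8192, 1))

# A single batched tensor of shape (16, 8192, 1) does not:
single = np.zeros((16, 8192, 1), np.float32)
try:
    toy_model_fn(single)
except ValueError as err:
    print("unpacking failed:", err)
```

This would explain why the dataset's element structure, not predict itself, is what matters.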


TFNetEstimator is a class that inherits from tf.estimator.Estimator with a modified model. We now need to call TFNetEstimator.predict. Using
input_fn=lambda: get_dummy_dataset().make_one_shot_iterator().get_next()
works fine, while using

        dset = ds.single_file_dataset(LQ_AUDIO_FILE,)
        # RunConfig with more printing since we are only training for very few steps
        config = tf.estimator.RunConfig(log_step_count_steps=1)

        tfnet_est = TFNetEstimator(**nets.default_net(), config=config,
                                   model_dir=DUMMY_MODEL_PATH)

        preds = tfnet_est.predict(input_fn=lambda: dset.make_one_shot_iterator().get_next())

raises an error. Let's look at the two dsets in detail:
1) ds.single_file_dataset:

def single_file_dataset(filename, upsample_rate=2, seg_length=8192, batchsize=16, **kwargs):
    """Loads a single audio file and processes it in order; use for prediction"""
    DEBUG("Ignored args: " + str(kwargs))
    audio_in = filters.upsample(_load_wav(filename), upsample_rate)
    audio_len, channels = audio_in.shape
    padlen = seg_length - audio_len % seg_length
    audio_padded = np.pad(audio_in, [(0, padlen), (0, 0)], 'constant')
    audio_segs = audio_padded.reshape((-1, seg_length, channels))

    def _gen():
        for seg in audio_segs:
            yield seg

    # Note: each element here is a single tensor, not a (low, high) pair
    dset = tf.data.Dataset.from_generator(_gen,
                                          output_types=tf.float32,
                                          output_shapes=[seg_length, channels])
    dset = dset.batch(batchsize)
    return dset
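The pad-and-reshape logic above can be checked with plain NumPy. One edge case worth noting (possibly harmless in practice): when audio_len is already an exact multiple of seg_length, padlen comes out as a full seg_length, so one extra all-zero segment gets appended:

```python
import numpy as np

seg_length, channels = 8, 1
audio_in = np.ones((20, channels), np.float32)  # 20 samples, not a multiple of 8

padlen = seg_length - audio_in.shape[0] % seg_length        # 8 - 4 = 4
audio_padded = np.pad(audio_in, [(0, padlen), (0, 0)], 'constant')
audio_segs = audio_padded.reshape((-1, seg_length, channels))
print(audio_segs.shape)  # (3, 8, 1): ceil(20 / 8) = 3 segments

# Edge case: length already an exact multiple of seg_length
exact = np.ones((16, channels), np.float32)
padlen = seg_length - exact.shape[0] % seg_length            # 8 - 0 = 8
padded = np.pad(exact, [(0, padlen), (0, 0)], 'constant')
print(padded.reshape((-1, seg_length, channels)).shape)  # (3, 8, 1): one extra zero segment
```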

2)get_dummy_dataset:

def get_dummy_dataset(length=8192, channels=1, count=16,
                      batchsize=4, repeat=200,
                      drop_remainder=True
                     ):
    """Dummy dataset generator for use in unit tests"""
    dummy_hr = np.array(np.linspace(0, 1, length)[:, np.newaxis], dtype=np.float32)
    dummy_hr = np.hstack([dummy_hr for _ in range(channels)])
    dummy_lr = dummy_hr.copy()
    dummy_lr[1::2] = 0

    dummy_train = [(dummy_lr.copy(), dummy_hr.copy()) for _ in range(count)]

    dummy_dset = tf.data.Dataset.from_generator(lambda: ((l, h) for l, h in dummy_train),
                                                output_types=(tf.float32, tf.float32),
                                                output_shapes=([length, channels],
                                                               [length, channels]))

    # e.g. count=16 samples per epoch, repeat=2 epochs, batch size 4 -> 8 iterations
    dummy_dset = dummy_dset.repeat(repeat).batch(batchsize, drop_remainder=drop_remainder)
    return dummy_dset
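The key structural difference is now visible: single_file_dataset yields a single tensor per element, while get_dummy_dataset yields a (low, high) 2-tuple, which appears to be what predict expects here. The dummy low/high construction itself is easy to verify in NumPy (zeroing every odd-indexed sample simulates a naive 2x downsample):

```python
import numpy as np

length, channels = 8192, 1
dummy_hr = np.array(np.linspace(0, 1, length)[:, np.newaxis], dtype=np.float32)
dummy_hr = np.hstack([dummy_hr for _ in range(channels)])
dummy_lr = dummy_hr.copy()
dummy_lr[1::2] = 0  # zero every odd-indexed sample

print(dummy_hr.shape, dummy_lr.shape)                  # (8192, 1) (8192, 1)
print(bool(np.all(dummy_lr[1::2] == 0)))               # True: odd samples zeroed
print(np.array_equal(dummy_lr[0::2], dummy_hr[0::2]))  # True: even samples kept
```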

https://tensorexamples.com/2020/07/27/Using-the-tf.data.Dataset.html
First, let's take a look at Dataset.

posted @ 2022-10-29 10:58  prettysky