1122

由YouTube8M的视频模型到音频模型转化

  youtube8M的接口的参数较为容易设置,首先文件夹的train.py文件

import json
import os
import time

import eval_util
import export_model
import losses
import frame_level_models
import video_level_models
import readers
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow import app
from tensorflow import flags
from tensorflow import gfile
from tensorflow import logging
from tensorflow.python.client import device_lib
import utils

  这些包含了引入的文件夹的其他写好的py文件和需要用到的库,如果没有的话pip安装即可。

1.模型的保存地址设置:

flags.DEFINE_string("train_dir", "/tmp/yt8m_model/",
                      "The directory to save the model files in.")

  需要多提一句的是flags这个模块是用于执行py程序的外部参数设置的交互,具体可以参见tf.app.flags/tf.flags

2.数据集的存储地址指定:

flags.DEFINE_string(
      "train_data_pattern", "E:/Audio_project/audioset/audioset_v1_embeddings/bal_train/*.tfrecord",
      "File glob for the training dataset. If the files refer to Frame Level "
      "features (i.e. tensorflow.SequenceExample), then set --reader_type "
      "format. The (Sequence)Examples are expected to have 'rgb' byte array "
      "sequence feature as well as a 'labels' int64 context feature.")

  看代码应该知道,这里介绍设置数据集的指向地址,注意地址的分隔符是'' /  ” , 其实“\\”也是可以的。

3.特征的名字(这一部分需要特别注意):

flags.DEFINE_string("feature_names", "audio_embedding", "Name of the feature "
                      "to use for training.")

   注意要设置成“audio_embedding”

4.帧特征设置

flags.DEFINE_bool(
      "frame_features", True,
      "If set, then --train_data_pattern must be frame-level features. "
      "Otherwise, --train_data_pattern must be aggregated video-level "
      "features. The model must also be set appropriately (i.e. to read 3D "
      "batches VS 4D batches.")

5.最重要的就是模型的选择设置

 flags.DEFINE_string(
      "model", "LstmModel",
      "Which architecture to use for the model. Models are defined "
      "in models.py.")

  主要要改成LstmModel,这个模型的设置其实你可以看frame_level_model.py和video_level_model.py两个文件定义的模型,

我们使用音频的帧特征肯定就要在frame_level_model.py选取模型。

 

 

posted @ 2018-11-22 21:32  陈柯成  阅读(799)  评论(0编辑  收藏  举报