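Converting the YouTube8M Video Model into an Audio Model
The parameters of the YouTube8M interface are fairly easy to set. Start with the train.py file in the project folder: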
import json
import os
import time
import eval_util
import export_model
import losses
import frame_level_models
import video_level_models
import readers
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow import app
from tensorflow import flags
from tensorflow import gfile
from tensorflow import logging
from tensorflow.python.client import device_lib
import utils
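These imports cover the other ready-made .py files in the same folder, plus the libraries the script needs. If a library is missing, install it with pip.
1. Setting the directory where the model is saved: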
flags.DEFINE_string("train_dir", "/tmp/yt8m_model/", "The directory to save the model files in.")
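One extra note: the flags module handles parameters passed in from outside when a .py program is run, that is, from the command line. For details, see tf.app.flags / tf.flags.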
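To make the idea concrete, here is a minimal sketch of the same pattern: a default defined in the script, optionally replaced from the command line. It uses the standard-library argparse as a stand-in, not tf.flags itself, and the override path /data/audio_model/ is a hypothetical value chosen only for illustration.

```python
import argparse

# Stand-in for flags.DEFINE_string("train_dir", "/tmp/yt8m_model/", ...).
# argparse is used here only to illustrate the idea; it is not tf.flags.
parser = argparse.ArgumentParser()
parser.add_argument("--train_dir", default="/tmp/yt8m_model/",
                    help="The directory to save the model files in.")

# No argument given: the default written in the script is used.
default_args = parser.parse_args([])
print(default_args.train_dir)

# Argument given on the command line: it replaces the default.
# "/data/audio_model/" is a hypothetical path.
override_args = parser.parse_args(["--train_dir=/data/audio_model/"])
print(override_args.train_dir)
```

In other words, editing the default in train.py and passing the value when launching the script are two ways of reaching the same setting.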
2. Specifying where the dataset is stored:
flags.DEFINE_string(
    "train_data_pattern",
    "E:/Audio_project/audioset/audioset_v1_embeddings/bal_train/*.tfrecord",
    "File glob for the training dataset. If the files refer to Frame Level "
    "features (i.e. tensorflow.SequenceExample), then set --reader_type "
    "format. The (Sequence)Examples are expected to have 'rgb' byte array "
    "sequence feature as well as a 'labels' int64 context feature.")
As the code shows, this flag sets the path that points to the dataset. Note that the path separator used here is "/"; "\\" works as well.
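The separator remark can be checked with a small sketch. It uses the standard-library ntpath module (the Windows path rules, importable on any OS) to compare the spellings of the dataset path; the variable names are made up for this example.

```python
import ntpath  # Windows path rules, available on every platform

# Three ways of writing the dataset pattern from the flag above.
forward = "E:/Audio_project/audioset/audioset_v1_embeddings/bal_train/*.tfrecord"
escaped = "E:\\Audio_project\\audioset\\audioset_v1_embeddings\\bal_train\\*.tfrecord"
raw = r"E:\Audio_project\audioset\audioset_v1_embeddings\bal_train\*.tfrecord"

# After normalisation all three spellings name the same Windows path.
same_path = (ntpath.normpath(forward)
             == ntpath.normpath(escaped)
             == ntpath.normpath(raw))
print(same_path)

# A single backslash in an ordinary string is risky: "\b" is read as the
# backspace character, so the folder name "bal_train" gets corrupted.
broken = "audioset_v1_embeddings\bal_train"
has_backspace = "\x08" in broken
print(has_backspace)
```

This is why a single "\" should be avoided in an ordinary string literal: use "/", a doubled "\\", or a raw string.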
3. The feature name (this part needs special attention):
flags.DEFINE_string("feature_names", "audio_embedding",
                    "Name of the feature "
                    "to use for training.")
Note that this must be set to "audio_embedding".
4. Frame feature setting
flags.DEFINE_bool(
    "frame_features", True,
    "If set, then --train_data_pattern must be frame-level features. "
    "Otherwise, --train_data_pattern must be aggregated video-level "
    "features. The model must also be set appropriately (i.e. to read 3D "
    "batches VS 4D batches.")
5. Most important of all is the model selection
flags.DEFINE_string(
    "model", "LstmModel",
    "Which architecture to use for the model. Models are defined "
    "in models.py.")
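The main change is to set this to LstmModel. To see which models are available, look at the models defined in the two files frame_level_models.py and video_level_models.py.
Since we are using frame-level audio features, the model has to be picked from frame_level_models.py.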
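Taken together, the five settings above could also be passed on the command line instead of editing the defaults in train.py. The following is a sketch under that assumption: it only assembles and prints the argument list, and the --name=value spelling is the usual form for such flags, so check it against your copy of train.py before running anything.

```python
# The five settings discussed above, with the values used in this post.
settings = {
    "train_dir": "/tmp/yt8m_model/",
    "train_data_pattern":
        "E:/Audio_project/audioset/audioset_v1_embeddings/bal_train/*.tfrecord",
    "feature_names": "audio_embedding",
    "frame_features": "True",
    "model": "LstmModel",
}

# Build an argument list: ["python", "train.py", "--train_dir=...", ...].
command = ["python", "train.py"]
command += ["--{}={}".format(name, value) for name, value in settings.items()]

# Print one argument per line so the result is easy to inspect.
for part in command:
    print(part)
```

The list could then be handed to subprocess.run(command) from the folder that contains train.py. Passing a list rather than one long string means no shell is involved, so the * in the data pattern reaches the script unchanged.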