TensorFlow解析TFRecord
1、序列化
#coding:utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from python_speech_features import logfbank
import scipy.io.wavfile as wav
import struct
# Root directory for every input and output of this script.
OUT_BASE_DIR = "/data/asr/duantiantian/data"
# wav.scp pairs an utterance id with its wav path, e.g.
# 00000210002 /netdisk1/asr_data/accented/00000210002.wav
# 00000210003 /netdisk1/asr_data/accented/00000210003.wav
wavscp_path = OUT_BASE_DIR + "/wav.scp.bk"
text_path = OUT_BASE_DIR + "/syllables"    # transcripts; units are joined by '#'
symbol_path = OUT_BASE_DIR + "/symbol.txt"  # "<unit>#<integer id>" per line

# Rejected utterances (one newline-terminated entry each), dumped at the end.
short_file = []
long_file = []
wrong_file = []
oov_file = []

# An utterance is rejected unless frame_count // STRIDE exceeds its label count.
STRIDE = 2
OUT_DIR = OUT_BASE_DIR + "/tfrecord"  # one .tfr file per kept utterance
ERROR = OUT_BASE_DIR + "/error_log"   # directory that receives the rejection lists

# Running totals for the feature statistics: sum(x), sum(x**2), frame count.
accsum1 = 0.
accsum2 = 0.
accnfrm = 0
def _bytes_feature(value):
    """Wrap one bytes object in a ``tf.train.Feature`` holding a one-element bytes_list."""
    single_item_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=single_item_list)
def get_feature(wavfile):
    """Compute 40-channel log filterbank features for one wav file.

    A file whose sample rate is not 8000 is reported (printed and appended to
    the module-level ``wrong_file`` list) but is still featurized and returned.

    Args:
        wavfile: path of the wav file to read.

    Returns:
        The array produced by ``logfbank``; the caller treats ``shape[0]`` as
        the frame count (frames are 25 ms long with a 10 ms hop).

    Raises:
        Whatever ``wav.read`` / ``logfbank`` raise; the main loop catches them.
    """
    sample_rate, signal = wav.read(wavfile)
    if sample_rate != 8000:
        print("*** %s: sample rate is not 8000 ***" % wavfile)
        wrong_file.append(wavfile + '\n')
    # NOTE(review): ``dither`` and ``wintype`` are not parameters of the
    # upstream python_speech_features.logfbank; presumably a patched fork —
    # confirm which package is installed.
    return logfbank(
        signal,
        samplerate=sample_rate,
        winlen=0.025,
        winstep=0.01,
        nfilt=40,
        nfft=512,
        lowfreq=64,
        highfreq=3800,
        preemph=0.97,
        dither=1,
        wintype='povey',
    )
def write_feat_to_tfr(id, feats, labels):
    """Serialize one utterance into its own TFRecord file ``OUT_DIR/<id>.tfr``.

    The file holds one ``tf.train.Example`` with two bytes features:
      * 'spectr': ``feats`` flattened row-major, stored as raw float32 bytes;
      * 'label':  ``labels`` stored as raw int32 bytes.

    Args:
        id: utterance id, used as the output file stem.
        feats: feature matrix (any numeric dtype; converted to float32).
        labels: sequence of integer unit ids.
    """
    # Log the utterance passed in as ``id`` (not a global loop variable).
    print("dtt--", id, "|", labels)
    outfilename = OUT_DIR + '/' + id + '.tfr'
    flat_feats = np.reshape(feats, [-1]).astype(np.float32)
    lab = np.array(labels, dtype=np.int32)
    example = tf.train.Example(features=tf.train.Features(feature={
        # tobytes() is the non-deprecated spelling of tostring(); same bytes.
        'spectr': _bytes_feature(flat_feats.tobytes()),
        'label': _bytes_feature(lab.tobytes()),
    }))
    # The context manager closes the writer even if write() raises.
    with tf.io.TFRecordWriter(outfilename) as writer:
        writer.write(example.SerializeToString())
# Map every modelling unit to its integer id; symbol.txt lines look like
# "<unit>#<id>".
ref2id = {}
with open(symbol_path, 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines instead of failing to unpack
        # Split on the LAST '#' so the id is always the final field.
        unit, unit_id = line.rsplit('#', 1)
        ref2id[unit] = int(unit_id)
# Utterance id -> wav path, read from wav.scp, whose lines look like
# "00000210002 /netdisk1/asr_data/accented/00000210002.wav".
wavlist = {}
# Utterance id -> '#'-joined unit string; filled from text_path next.
textlist = {}
with open(wavscp_path, 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines (e.g. a trailing newline)
        # Split once on any whitespace so both tab- and space-separated files
        # work; everything after the id is kept verbatim as the path.
        utt_id, path = line.split(None, 1)
        wavlist[utt_id] = path
# Fill textlist from the transcript file: "<utt-id> <unit>#<unit>..." per line.
with open(text_path, 'r') as transcript_file:
    for raw in transcript_file:
        # Drop the final character after stripping (presumably a trailing '#'
        # — confirm against the data).
        trimmed = raw.strip()[:-1]
        # Only keep lines that still have a transcription after the id;
        # utterances with an empty transcription are skipped.
        if len(trimmed.split(' ')) != 1:
            utt_id, units = trimmed.split()
            textlist[utt_id] = units
# For every utterance that has both audio and a transcript:
#   1. extract features;
#   2. reject unusable ones, logging each rejection into a *_file list;
#   3. write one TFRecord per kept utterance;
#   4. accumulate per-dimension sums for the statistics computed afterwards.
common = {}
filesize = []
for key in wavlist:
    if key not in textlist:
        continue
    common[key] = (wavlist[key], textlist[key])
    try:
        curr_feature = get_feature(wavlist[key])
    except Exception:
        # Unreadable or malformed audio: record it and move on. (A bare
        # ``except:`` would also swallow KeyboardInterrupt / SystemExit.)
        wrong_file.append(wavlist[key] + '\n')
        continue
    try:
        labels = [ref2id[ele] for ele in textlist[key].split("#")]
    except KeyError as err:
        # The transcript uses a unit missing from symbol.txt: skip this
        # utterance instead of aborting the whole run.
        print("%s has label %s not in symbol table" % (key, err))
        oov_file.append(key + '\n')
        continue
    frame_num = curr_feature.shape[0]
    # Frame limits assume the 10 ms hop used in get_feature (winstep=0.01).
    if frame_num < 40:
        print("%s less than 0.4s" % wavlist[key])
        short_file.append(wavlist[key] + '\n')
        continue
    if frame_num > 1500:
        print("%s longer than 15s" % wavlist[key])
        long_file.append(wavlist[key] + '\n')
        continue
    # Require more (frames // STRIDE) than labels; presumably the model
    # downsamples time by STRIDE and needs at least one frame per label
    # — confirm.
    if frame_num // STRIDE <= len(labels):
        print("%s label size longer than frame number" % key)
        oov_file.append(key + '\n')
        continue
    write_feat_to_tfr(key, curr_feature, labels)
    accsum1 += np.sum(curr_feature, 0)
    accsum2 += np.sum(np.square(curr_feature), 0)
    accnfrm += frame_num
    # Byte size of the float32 'spectr' payload just written.
    payload_bytes = curr_feature.size * np.dtype(np.float32).itemsize
    filesize.append(key + '\t' + str(payload_bytes) + '\n')
# Record the payload size of every written TFRecord.
with open(OUT_BASE_DIR + "/tfr.size", 'w') as size_out:
    size_out.writelines(filesize)
print("tfr Done")

# Dump each non-empty rejection list into its own file under ERROR.
rejection_logs = (
    ("/short_file", short_file),
    ("/long_file", long_file),
    ("/wrong_file", wrong_file),
    ("/oov_file", oov_file),
)
for log_suffix, entries in rejection_logs:
    if entries:
        with open(ERROR + log_suffix, 'w') as log_out:
            log_out.writelines(entries)
# Per-dimension statistics over every frame that was written:
#   fmean = -mean          (negated mean)
#   fvar  = 1 / stddev     (inverse standard deviation)
# NOTE(review): these look like CMVN shift/scale values, but they are only
# printed, never written to disk, despite the "Save done" message — confirm.
if accnfrm == 0:
    # Otherwise the division below fails with an opaque ZeroDivisionError.
    raise SystemExit("No utterance passed the filters; cannot compute feature statistics")
accsum1 = -accsum1 / accnfrm
# Var[x] = E[x^2] - E[x]^2; squaring the negated mean gives the same E[x]^2.
accsum2 = 1.0 / np.sqrt(accsum2 / accnfrm - np.square(accsum1))
fmean = np.array(accsum1, dtype=float)
fvar = np.array(accsum2, dtype=float)
print(fmean, fvar)
print('Save done')
#print(curr_feature.shape)
#print(curr_feature[0,:])
2、解析
#coding:utf-8
# CNN+FSMN
# Author: Jie Ma
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import re
import time
import os
import errno
import sys
import math
import struct
import ConfigParser
# from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops
# TFR_LIST_FILE="/data/asr/duantiantian/data/tfrecord/eb39a10b280ecdfa04a581da2f002ebf_0114_0002.tfr"
def read_and_decode(TFR_LIST_FILE):
    """Read one example from a single TFRecord file using the TF1 queue API.

    Args:
        TFR_LIST_FILE: path of ONE .tfr file; it is wrapped in a list for the
            filename queue. (The parameter shadows the module constant.)

    Returns:
        (spectr, label): 1-D float32 tensor of the flattened features and
        1-D int32 tensor of label ids.

    NOTE(review): queue-based input needs ``tf.train.start_queue_runners``
    before ``sess.run``, otherwise the read blocks — confirm at the caller.
    """
    print("---------------33----------------")
    record_schema = {
        'spectr': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.string),
    }
    file_queue = tf.train.string_input_producer([TFR_LIST_FILE])  # queue of filenames
    # read() yields (record key, serialized record); only the record is used.
    _, serialized = tf.TFRecordReader().read(file_queue)
    parsed = tf.parse_single_example(serialized, features=record_schema)
    spectr = tf.decode_raw(parsed['spectr'], tf.float32)
    label = tf.reshape(tf.decode_raw(parsed['label'], tf.int32), [-1])
    return spectr, label
def parse_exmp(serialized_example):
    """Parse one serialized example for the tf.data pipeline.

    Args:
        serialized_example: scalar string tensor holding one tf.train.Example.

    Returns:
        frames: float32 tensor reshaped to [-1, FREQ_BIN_NUM, CHANNEL_NUM].
        label: 1-D int32 label ids, each shifted by +1 so that 0 is left free
            for batch padding (sparse() later drops zeros and subtracts 1).
        frame_count: size of dimension 0 of ``frames``.
    """
    # Both keys are required, so no defaults are given.
    schema = {
        'spectr': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, features=schema)
    label = tf.reshape(tf.decode_raw(parsed['label'], tf.int32), [-1]) + 1
    frames = tf.decode_raw(parsed['spectr'], tf.float32)
    frames = tf.reshape(frames, [-1, FREQ_BIN_NUM, CHANNEL_NUM])
    return frames, label, tf.shape(frames)[0]
# Settings for the parsing pipeline below.
TFR_LIST_FILE = "tmp.txt"  # text file listing one .tfr path per line
BATCH_SIZE = 1
FREQ_BIN_NUM = 40          # feature values per frame; must match the serialized data
CHANNEL_NUM = 1            # trailing channel axis added by parse_exmp's reshape
def sparse(example, label, example_length):
    """Convert a padded label batch into a SparseTensor of the original ids.

    Args:
        example: padded feature batch (passed through unchanged).
        label: [batch, max_len] int32 labels; real labels were shifted by +1
            in parse_exmp, so 0 only ever marks padding.
        example_length: per-utterance frame counts.

    Returns:
        (example, targets, example_length), where ``targets`` holds the
        non-padding labels shifted back by -1 and ``example_length`` is 1-D.
    """
    # ``[-1]`` instead of ``[BATCH_SIZE]``: padded_batch keeps a smaller final
    # batch, which would fail a fixed-size reshape when BATCH_SIZE > 1.
    example_length = tf.reshape(example_length, [-1])
    indices = tf.where(tf.not_equal(label, 0))  # positions of real labels
    targets = tf.SparseTensor(indices=indices,
                              values=tf.gather_nd(label, indices) - 1,
                              dense_shape=tf.cast(tf.shape(label), tf.int64))
    return example, targets, example_length
# Input pipeline:
#   1. TFR_LIST_FILE lists one .tfr path per line;
#   2. every record is parsed, then padded into batches of BATCH_SIZE;
#   3. each batch's labels are converted to a SparseTensor.
dataset = tf.data.TextLineDataset(TFR_LIST_FILE)
dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=BATCH_SIZE)
dataset = dataset.map(parse_exmp, num_parallel_calls=16)
# Features are padded along time, labels along length; lengths are scalars.
dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=([None, FREQ_BIN_NUM, CHANNEL_NUM], [None], []))
dataset = dataset.map(sparse, num_parallel_calls=16)
# Prefetch as the last step so finished batches are buffered while the
# consumer runs (10 batches = 10 * BATCH_SIZE utterances).
dataset = dataset.prefetch(buffer_size=10)
iterator = dataset.make_one_shot_iterator()
# with tf.Session() as sess: #开始一个会话
#
# sess.run(tf.global_variables_initializer())
# sess.run(tf.local_variables_initializer())
# spectr,label=read_and_decode(TFR_LIST_FILE)
# print("-----------11-----------")
# init_op = tf.global_variables_initializer()
# sess.run(init_op)
# spec, label = sess.run([spectr,label])#在会话中取出image和label
# print("-----------22-----------")
# print(label)
def get_batch():
    """Return the (features, sparse labels) tensors of the next batch; lengths are discarded."""
    features, sparse_labels, _unused_lengths = iterator.get_next()
    return features, sparse_labels
# Print every label batch until the dataset is exhausted.
spect, label = get_batch()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ``iterator`` is a one-shot iterator: it initializes itself on first
    # use, so no explicit initializer op is run.
    try:
        while True:
            print(sess.run([label]))
    except tf.errors.OutOfRangeError:
        # Raised by get_next() once every record has been consumed.
        print("outOfRange")