TensorFlow 解析 TFRecord

1、序列化

#coding:utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import struct

import numpy as np
import scipy.io.wavfile as wav
import tensorflow as tf
from python_speech_features import logfbank

# Root directory: all inputs are read from here and all outputs written below it.
OUT_BASE_DIR = "/data/asr/duantiantian/data"

# Input lists. wav.scp maps an utterance id to its audio path, e.g.
#   00000210002     /netdisk1/asr_data/accented/00000210002.wav
#   00000210003     /netdisk1/asr_data/accented/00000210003.wav
wavscp_path = "%s/wav.scp.bk" % OUT_BASE_DIR
text_path = "%s/syllables" % OUT_BASE_DIR
symbol_path = "%s/symbol.txt" % OUT_BASE_DIR

# Rejected utterances, collected per reason and dumped to the error logs.
short_file = []
long_file = []
wrong_file = []
oov_file = []

# An utterance is kept only if len(labels) < frame_num // STRIDE.
STRIDE = 2
OUT_DIR = "%s/tfrecord" % OUT_BASE_DIR
ERROR = "%s/error_log" % OUT_BASE_DIR

# Running per-dimension sums / squared sums and total frame count used for
# the global feature statistics.
accsum1 = 0.
accsum2 = 0.
accnfrm = 0


def _bytes_feature(value):
    """Wrap a single bytes payload in a ``tf.train.Feature`` (bytes_list)."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)

def get_feature(wavfile):
    """Compute log mel filterbank features (nfilt=40) for one WAV file.

    Args:
        wavfile: path to the WAV file.

    Returns:
        The array returned by ``logfbank`` (one row per 10 ms frame).

    NOTE(review): a sample rate other than 8000 is only printed and recorded
    in ``wrong_file``; the features are still computed and returned, so the
    caller will go on to process this file.
    """
    sample_rate, samples = wav.read(wavfile)
    if sample_rate != 8000:
        print("*** %s: sample rate is not 8000 ***" % wavfile)
        wrong_file.append(wavfile+'\n')
    return logfbank(samples, samplerate=sample_rate,
                    winlen=0.025, winstep=0.01,
                    nfilt=40, nfft=512,
                    lowfreq=64, highfreq=3800,
                    preemph=0.97, dither=1, wintype='povey')

def write_feat_to_tfr(id,feats,labels):
    """Serialize one utterance into its own TFRecord file ``OUT_DIR/<id>.tfr``.

    The record holds a single ``tf.train.Example`` with two bytes features:
      * ``'spectr'``: the features flattened to 1-D and stored as raw float32;
      * ``'label'``: the label ids stored as raw int32.

    Args:
        id: utterance id, used for logging and as the output file name.
        feats: feature matrix for the utterance (flattened before storing).
        labels: sequence of integer label ids.
    """
    # Log the parameter, not the caller's loop variable (was the global `key`).
    print("dtt--", id, "|", labels)
    outfilename = OUT_DIR + '/' + id + '.tfr'
    spectr = np.reshape(feats, [-1]).astype(np.float32)
    lab = np.array(labels, dtype=np.int32)

    example = tf.train.Example(features=tf.train.Features(feature={
        # tobytes() replaces the deprecated tostring(); same raw bytes.
        'spectr': _bytes_feature(spectr.tobytes()),
        'label': _bytes_feature(lab.tobytes()),
    }))
    # The context manager closes the writer even if the write raises.
    with tf.io.TFRecordWriter(outfilename) as writer:
        writer.write(example.SerializeToString())

# Model unit (symbol) -> integer index, from symbol.txt where every
# non-blank line has the form "<symbol>#<integer id>".
ref2id = {}
with open(symbol_path,'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at EOF), which
            # previously made the two-way unpack below raise ValueError.
            continue
        ref, sym_id = line.split('#')
        ref2id[ref] = int(sym_id)


# Utterance id -> audio path, read from wav.scp ("<id> <path>" per line).
wavlist = {}
# Utterance id -> '#'-joined syllable transcript.
textlist = {}
with open(wavscp_path,'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Skip blank lines instead of failing the unpack below.
            continue
        # Split on the first run of whitespace so both tab- and
        # space-separated lines parse; the path itself is kept intact.
        utt_id, path = line.split(None, 1)
        wavlist[utt_id] = path

with open(text_path,'r') as f:
    for line in f:
        # NOTE(review): [:-1] drops the last character of every stripped line
        # (presumably a trailing separator) — confirm against the file format.
        line = line.strip()[:-1]
        if len(line.split(' ')) == 1:     # empty transcript: skip this utterance
            continue
        utt_id, sylls = line.split()
        textlist[utt_id] = sylls

# For every utterance with both audio and a non-empty transcript: extract
# features, drop unusable utterances (logging the reason), write one TFRecord
# per kept utterance and accumulate per-dimension sums for global statistics.
common = {}    # key -> (wav path, transcript); not read elsewhere in this script
filesize = []  # "key\t<payload bytes>\n" lines, written to tfr.size
for key in wavlist:
    if key not in textlist:
        continue
    wav_path = wavlist[key]
    common[key] = (wav_path, textlist[key])
    try:
        curr_feature = get_feature(wav_path)
    except Exception:
        # Unreadable/corrupt audio: record it and move on. (A bare `except:`
        # would also swallow KeyboardInterrupt and SystemExit.)
        wrong_file.append(wav_path+'\n')
        continue

    # Map each '#'-separated syllable to its integer id.
    try:
        labels = [ref2id[ele] for ele in textlist[key].split("#")]
    except KeyError as err:
        # Syllable missing from symbol.txt: log it instead of aborting the run.
        print("%s has unknown symbol %s" % (key, err))
        oov_file.append(key+'\n')
        continue

    frame_num = curr_feature.shape[0]
    if frame_num < 40:  # 40 frames at a 10 ms hop = 0.4 s
        print("%s less than 0.4s" % wav_path)
        short_file.append(wav_path+'\n')
        continue
    if frame_num > 1500:  # 1500 frames at a 10 ms hop = 15 s
        print("%s longer than 15s" % wav_path)
        long_file.append(wav_path+'\n')
        continue
    if frame_num//STRIDE <= len(labels):
        # Need strictly more subsampled frames than labels.
        print("%s label size longer than frame number" % key)
        oov_file.append(key+'\n')
        continue

    write_feat_to_tfr(key, curr_feature, labels)
    accsum1 += np.sum(curr_feature, 0)
    accsum2 += np.sum(np.square(curr_feature), 0)
    accnfrm += frame_num
    # Stored payload size: frames * 40 filterbank dims * 4 bytes (float32).
    filesize.append(key+'\t'+str(frame_num*40*4)+'\n')


# Per-utterance payload sizes of the written TFRecords.
with open(OUT_BASE_DIR+"/tfr.size",'w') as f:
    f.writelines(filesize)
print("tfr Done")

# Dump each non-empty rejection list to ERROR/<name>. The directory is created
# on demand so a missing error_log/ does not raise at the very end of a long
# run and lose every collected log.
for log_name, entries in (("short_file", short_file),
                          ("long_file", long_file),
                          ("wrong_file", wrong_file),
                          ("oov_file", oov_file)):
    if not entries:
        continue
    if not os.path.isdir(ERROR):
        os.makedirs(ERROR)
    with open(ERROR + "/" + log_name, 'w') as f:
        f.writelines(entries)


# Global per-dimension feature statistics over all written frames.
# NOTE: despite the names, `fmean` holds the *negated* mean and `fvar` holds
# the inverse standard deviation 1/sqrt(E[x^2] - E[x]^2) of each dimension.
if accnfrm == 0:
    # No utterance passed the filters; dividing by accnfrm would raise.
    print("no frames accumulated, skip mean/var statistics")
else:
    accsum1 = -accsum1/accnfrm
    aux = np.ones(40)
    accsum2 = np.divide(aux, np.sqrt(np.subtract((accsum2/accnfrm), np.square(accsum1))))

    # Independent float64 copies of the 40-dim statistics.
    fmean = np.array(accsum1, dtype=float)
    fvar = np.array(accsum2, dtype=float)
    print(fmean, fvar)
print('Save done')
#print(curr_feature.shape)
#print(curr_feature[0,:])

 

2、解析

#coding:utf-8
# CNN+FSMN
# Author: Jie Ma

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import re
import time
import os
import errno
import sys
import math
import struct
import ConfigParser

# from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops


# TFR_LIST_FILE="/data/asr/duantiantian/data/tfrecord/eb39a10b280ecdfa04a581da2f002ebf_0114_0002.tfr"
def read_and_decode(TFR_LIST_FILE):
    """Build a TF1 queue-based reader for one TFRecord file.

    Args:
        TFR_LIST_FILE: path of a single TFRecord file (not a list file).

    Returns:
        (spectr, label): the raw ``'spectr'`` bytes decoded as float32 and the
        ``'label'`` bytes decoded as int32, each reshaped/left as 1-D tensors.

    NOTE(review): ``string_input_producer`` is queue-based; per the TF1 docs
    queue runners must be started before these tensors are evaluated —
    confirm that callers do so.
    """
    print("---------------33----------------")
    # Filename queue containing just this one file.
    file_queue = tf.train.string_input_producer([TFR_LIST_FILE])
    # read() yields (key, serialized record); only the record is needed.
    _, record = tf.TFRecordReader().read(file_queue)
    schema = {
        'spectr': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(record, features=schema)
    spectr = tf.decode_raw(parsed['spectr'], tf.float32)
    label = tf.reshape(tf.decode_raw(parsed['label'], tf.int32), [-1])
    return spectr, label

def parse_exmp(serialized_example):
    """Decode one serialized Example into (features, labels, num_frames).

    Returns:
        frames: float32 tensor reshaped to [-1, FREQ_BIN_NUM, CHANNEL_NUM].
        labels: 1-D int32 label ids, each shifted by +1 so that 0 can act as
            padding after ``padded_batch``; ``sparse`` drops zeros and
            subtracts 1 again.
        num_frames: size of the first dimension of ``frames``.
    """
    # Both keys are required, so no default values are given.
    feature_spec = {
        'spectr': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    labels = tf.decode_raw(parsed['label'], tf.int32)
    labels = tf.reshape(labels, [-1]) + 1

    frames = tf.decode_raw(parsed['spectr'], tf.float32)
    frames = tf.reshape(frames, [-1, FREQ_BIN_NUM, CHANNEL_NUM])
    return frames, labels, tf.shape(frames)[0]


# Text file listing one TFRecord path per line (read by TextLineDataset).
TFR_LIST_FILE = "tmp.txt"
# Batch size and the per-frame layout used to reshape decoded spectra.
BATCH_SIZE, FREQ_BIN_NUM, CHANNEL_NUM = 1, 40, 1


def sparse(example, label, example_length):
    """Turn a zero-padded batch of shifted labels into a SparseTensor.

    Args:
        example: batched features, passed through unchanged.
        label: [batch, max_len] int32 labels; real ids were shifted by +1 in
            ``parse_exmp``, so 0 only marks padding.
        example_length: per-example frame counts.

    Returns:
        (example, targets, example_length), where ``targets`` is a
        SparseTensor holding the unshifted ids at the non-padding positions,
        and ``example_length`` is flattened to 1-D.
    """
    # Flatten with -1 rather than BATCH_SIZE: padded_batch keeps the final,
    # possibly smaller, batch, on which a fixed-size reshape would fail.
    example_length = tf.reshape(example_length, [-1])
    # Compare the int labels directly; no float cast is needed to find padding.
    indices = tf.where(tf.not_equal(label, 0))
    targets = tf.SparseTensor(indices=indices,
                              values=tf.gather_nd(label, indices) - 1,
                              dense_shape=tf.cast(tf.shape(label), tf.int64))
    return example, targets, example_length


# Input pipeline: list file -> TFRecord records -> parsed examples ->
# zero-padded batches -> (features, sparse targets, lengths).
dataset = tf.data.TextLineDataset(TFR_LIST_FILE)
dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=BATCH_SIZE)
dataset = dataset.map(parse_exmp, num_parallel_calls=16)
dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=([None, FREQ_BIN_NUM, CHANNEL_NUM], [None], []))
dataset = dataset.map(sparse, num_parallel_calls=16)
# Prefetch as the last step so batching and sparsifying also overlap with the
# consumer. The buffer now counts batches: 10 batches ~ 10 * BATCH_SIZE examples.
dataset = dataset.prefetch(buffer_size=10)
iterator = dataset.make_one_shot_iterator()


# with tf.Session() as sess: #开始一个会话
#
#     sess.run(tf.global_variables_initializer())
#     sess.run(tf.local_variables_initializer())
    # spectr,label=read_and_decode(TFR_LIST_FILE)
    # print("-----------11-----------")
    # init_op = tf.global_variables_initializer()
    # sess.run(init_op)
    # spec, label = sess.run([spectr,label])#在会话中取出image和label
    # print("-----------22-----------")
    # print(label)


def get_batch():
    """Return (features, sparse_targets) tensors for the next pipeline batch.

    The per-example lengths produced by the pipeline are discarded.
    """
    features, targets, _lengths = iterator.get_next()
    return features, targets


# Drain the pipeline once, printing the fetched sparse labels of every batch.
# NOTE(review): `iterator` is created with make_one_shot_iterator(); confirm
# that make_initializer() is supported on it in the TF version used.
spect, label = get_batch()
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    session.run(iterator.make_initializer(dataset))
    while True:
        try:
            fetched = session.run([label])
        except tf.errors.OutOfRangeError:
            # Raised once the dataset is exhausted.
            print("outOfRange")
            break
        print(fetched)
posted @ 2022-10-22 22:53  7aughing  阅读(133)  评论(0)  编辑  收藏  举报