写入Tfrecord
# Serialize the training split to a TFRecord file.
# Each record stores the per-token id arrays as raw bytes (numpy tostring())
# and the scalar fields (seq_len, label, task) as int64 features.
#
# NOTE(review): these records are tf.train.Example protos with bytes
# features, but the companion parser below uses
# tf.parse_single_sequence_example with int64 sequence features — the two
# sides look inconsistent; confirm which format the pipeline actually reads.
print("convert data into tfrecord:train\n")
out_file_train = "fudan_mtl/data/ace2005/bn_nw.train.tfrecord"
writer = tf.python_io.TFRecordWriter(out_file_train)

def _bytes_feature(arr):
    """Wrap a numpy array's raw bytes in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[arr.tostring()]))

def _int64_feature(value):
    """Wrap a scalar integer in a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

for i in tqdm(range(len(data_train))):
    record = tf.train.Example(features=tf.train.Features(feature={
        'word_ids': _bytes_feature(train_x[i]),
        'et_ids1': _bytes_feature(train_et1[i]),
        'et_ids2': _bytes_feature(train_et2[i]),
        'position_ids1': _bytes_feature(train_p1[i]),
        # BUG FIX: the original wrote train_p1[i] here as well, so both
        # position channels carried the first entity's offsets; the second
        # channel must come from train_p2.
        'position_ids2': _bytes_feature(train_p2[i]),
        'chunks': _bytes_feature(train_chunks[i]),
        'spath_ids': _bytes_feature(train_spath[i]),
        'seq_len': _int64_feature(train_x_len[i]),
        # train_relation[i] is a one-hot/score vector; store the class index.
        'label': _int64_feature(np.argmax(train_relation[i])),
        # Task id 0 marks this corpus in the multi-task setup.
        'task': _int64_feature(np.int64(0)),
    }))
    writer.write(record.SerializeToString())
writer.close()
解析tfrecord
def _parse_tfexample(serialized_example):
    """Deserialize one tf.train.SequenceExample into (task, label, sentence).

    Context features hold the scalars (label, task, seq_len); sequence
    features hold the per-token id tensors. The returned ``sentence`` is a
    tuple of the seven id sequences followed by seq_len.
    """
    context_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'task': tf.FixedLenFeature([], tf.int64),
        'seq_len': tf.FixedLenFeature([], tf.int64),
    }
    sequence_spec = {
        'word_ids': tf.FixedLenSequenceFeature([], tf.int64),
        'et_ids1': tf.FixedLenSequenceFeature([], tf.int64),
        'et_ids2': tf.FixedLenSequenceFeature([], tf.int64),
        'position_ids1': tf.FixedLenSequenceFeature([], tf.int64),
        'position_ids2': tf.FixedLenSequenceFeature([], tf.int64),
        'chunks': tf.FixedLenSequenceFeature([], tf.int64),
        'spath_ids': tf.FixedLenSequenceFeature([], tf.int64),
    }
    context, sequence = tf.parse_single_sequence_example(
        serialized_example,
        context_features=context_spec,
        sequence_features=sequence_spec)
    sentence = (
        sequence['word_ids'],
        sequence['et_ids1'],
        sequence['et_ids2'],
        sequence['position_ids1'],
        sequence['position_ids2'],
        sequence['chunks'],
        sequence['spath_ids'],
        context['seq_len'],
    )
    return context['task'], context['label'], sentence


def read_tfrecord(epoch, batch_size):
    """Yield (train_data, test_data) pairs, one per corpus in DATASETS."""
    for dataset in DATASETS:
        train_record_file = os.path.join(OUT_DIR, dataset + '.train.tfrecord')
        test_record_file = os.path.join(OUT_DIR, dataset + '.test.tfrecord')
        train_data = util.read_tfrecord(train_record_file, epoch, batch_size,
                                        _parse_tfexample, shuffle=True)
        test_data = util.read_tfrecord(test_record_file, epoch, batch_size,
                                       _parse_tfexample, shuffle=False)
        yield train_data, test_data
模型中使用:
def build_task_graph(self, data):
    """Unpack one parsed tfrecord tuple and cache the id tensors on self.

    ``data`` is the (task_label, labels, sentence) triple produced by the
    tfrecord parser, where ``sentence`` bundles the seven per-token id
    tensors plus seq_len.

    NOTE(review): the visible snippet ends immediately after the embedding
    call — the remainder of the task graph (encoder/loss) is presumably
    built further down in the original file; confirm before relying on this.
    """
    task_label, labels, sentence = data
    # sentence = tf.nn.embedding_lookup(self.word_embed, sentence)
    ##########################
    word_ids, et_ids1, et_ids2, position_ids1, position_ids2, chunks, spath_ids, seq_len = sentence
    # sentence = word_ids
    #########################
    # Stash each channel on the instance so later layers can reach them.
    self.word_ids = word_ids
    self.position_ids1 = position_ids1
    self.position_ids2 = position_ids2
    self.et_ids1 = et_ids1
    self.et_ids2 = et_ids2
    self.chunks_ids = chunks
    self.spath_ids = spath_ids
    self.seq_len = seq_len
    # Fuse the id channels into one embedded sentence representation.
    sentence = self.add_embedding_layers()
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧