mxnet 中变长输入的实现

mxnet中,每个输入到RNN的batch 长度可不一样,但是mxnet中要求 batch内的长度是一样的。这里采用的方法是,利用 gluonnlp让每个batch内 seq的长度尽量一样。

假设进行的是 单输入的分类任务

import gluonnlp as nlp
from mxnet.gluon import data as gdata

'''
获取句子长度
'''
train_data_lengths=list()
for q in querys: # querys是list数据,是载入的序列数据,querys每个元素是一个样本,每个样本是一个list,list的元素是 单词的编号
    train_data_lengths.append(len(q))

'''
准备数据
'''
train_gdata=gdata.ArrayDataset(querys, labels) # labels 是list数据,是类别标签,labels每个元素是一个样本的标签编号 

'''
准备处理工具 batchify_fn
nlp.data.batchify.Tuple 把工具整合起来
nlp.data.batchify.Pad 处理 序列数据的pad
nlp.data.batchify.Stack 把标签数据整理成 mxnet.ndarray 
'''
batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(axis=0,pad_val=0),  # 处理 序列数据的pad,得到的是mxnet.ndarray
    nlp.data.batchify.Stack() # 处理标签数据,得到的是mxnet.ndarray
) 

'''
准备batch采样工具 batch_sampler
'''
batch_size=16
# batch_sampler = nlp.data.sampler.FixedBucketSampler(train_data_lengths,batch_size=batch_size,num_buckets=10,ratio=0.5,shuffle=True) # 不一定能得到 大小为batch_size的batch
batch_sampler=nlp.data.sampler.SortedBucketSampler(train_data_lengths,batch_size=batch_size,mult=100,shuffle=True) # 一定能得到 大小为batch_size的batch
    
'''
准备DataLoader
'''
train_dataloader = gluon.data.DataLoader(train_gdata,batch_sampler=batch_sampler,batchify_fn=batchify_fn)
step=0
for query,label in train_dataloader:
     print query.shape,label.shape

假设进行的是 多输入的生成任务

import gluonnlp as nlp
from mxnet.gluon import data as gdata

'''
获取句子长度  !需要注意, train_data_lengths的每个元素需要换成 tuple of int or list of int,表示一个样本中多个序列的长度
'''
train_data_lengths=list()
# querys、reply1s、reply2s、target_replies都是list数据,是序列数据,每个元素是一个样本,每个样本是一个list,list的元素是 单词的编号
for i in range(len(train_data.queries)):
    length1=len(querys[i])

    length2=len(reply1s[i])
    length3=len(reply2s[i])
    length4=len(target_replies[i])

    train_data_lengths.append((length1, length2, length3, length4))

'''
准备数据
'''
train_gdata=gdata.ArrayDataset(querys, reply1s, reply2s, target_replies) # labels 是list数据,是类别标签,labels每个元素是一个样本的标签编号 

'''
准备处理工具 batchify_fn
nlp.data.batchify.Tuple 把工具整合起来
nlp.data.batchify.Pad 处理 序列数据的pad
nlp.data.batchify.Stack 把标签数据整理成 mxnet.ndarray 
'''
batchify_fn = nlp.data.batchify.Tuple(  # 有多少个序列数据,就有多少个nlp.data.batchify.Pad
    nlp.data.batchify.Pad(axis=0,pad_val=0),  # 处理 序列数据的pad,得到的是mxnet.ndarray
    nlp.data.batchify.Pad(axis=0,pad_val=0),  # 处理 序列数据的pad,得到的是mxnet.ndarray
    nlp.data.batchify.Pad(axis=0,pad_val=0),  # 处理 序列数据的pad,得到的是mxnet.ndarray
    nlp.data.batchify.Pad(axis=0,pad_val=0)  # 处理 序列数据的pad,得到的是mxnet.ndarray
) 

'''
准备batch采样工具 batch_sampler
'''
batch_size=16
# batch_sampler = nlp.data.sampler.FixedBucketSampler(train_data_lengths,batch_size=batch_size,num_buckets=10,ratio=0.5,shuffle=True) # 不一定能得到 大小为batch_size的batch
batch_sampler=nlp.data.sampler.SortedBucketSampler(train_data_lengths,batch_size=batch_size,mult=100,shuffle=True) # 一定能得到 大小为batch_size的batch

'''
准备DataLoader
'''
train_dataloader = gluon.data.DataLoader(train_gdata,batch_sampler=batch_sampler,batchify_fn=batchify_fn)
step=0
for query, reply1,reply2,target_reply in train_dataloader:
     print query.shape, reply1.shape, reply2.shape, target_reply.shape
posted @ 2019-07-23 19:41  hui_lyh  阅读(529)  评论(0编辑  收藏  举报