XLNet Chinese Text Classification Task
First, convert the data to TFRecord format:
import tensorflow as tf
import sys
import six
import unicodedata
import sentencepiece as spm
import collections
from textclass import FLAGS


SEG_ID_A = 0
SEG_ID_B = 1
SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4

special_symbols = {
    "<unk>": 0,
    "<s>": 1,
    "</s>": 2,
    "<cls>": 3,
    "<sep>": 4,
    "<pad>": 5,
    "<mask>": 6,
    "<eod>": 7,
    "<eop>": 8,
}

VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]

sp = spm.SentencePieceProcessor()
sp.Load(FLAGS.spiece_model_file)
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()
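_truncate_seq_pair trims whichever sequence is currently longer, one token at a time, until the pair fits into max_length. A quick illustration with dummy token ids:

a, b = [1, 2, 3, 4, 5], [6, 7]
_truncate_seq_pair(a, b, max_length=5)
print(a, b)  # [1, 2, 3] [6, 7]  -- only the longer list was shortened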
def get_class_ids(text, max_seq_length, tokenize_fn):
    texts = tokenize_fn(text)
    if len(texts) > max_seq_length - 2:
        texts = texts[:max_seq_length - 2]
    tokens = []
    segment_ids = []
    for token in texts:
        tokens.append(token)
        segment_ids.append(SEG_ID_A)
    tokens.append(SEP_ID)
    segment_ids.append(SEG_ID_A)

    tokens.append(CLS_ID)
    segment_ids.append(SEG_ID_CLS)

    input_ids = tokens
    input_mask = [0] * len(input_ids)
    if len(input_ids) < max_seq_length:
        delta_len = max_seq_length - len(input_ids)
        input_ids = [0] * delta_len + input_ids
        input_mask = [1] * delta_len + input_mask
        segment_ids = [SEG_ID_PAD] * delta_len + segment_ids

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    return input_ids, input_mask, segment_ids
def get_pair_ids(text_a, text_b, max_seq_length, tokenize_fn):
    tokens_a = tokenize_fn(text_a)
    tokens_b = tokenize_fn(text_b)
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

    tokens = []
    segment_ids = []
    for token in tokens_a:
        tokens.append(token)
        segment_ids.append(SEG_ID_A)
    tokens.append(SEP_ID)
    segment_ids.append(SEG_ID_A)

    for token in tokens_b:
        tokens.append(token)
        segment_ids.append(SEG_ID_B)
    tokens.append(SEP_ID)
    segment_ids.append(SEG_ID_B)

    tokens.append(CLS_ID)
    segment_ids.append(SEG_ID_CLS)

    input_ids = tokens
    input_mask = [0] * len(input_ids)

    if len(input_ids) < max_seq_length:
        delta_len = max_seq_length - len(input_ids)
        input_ids = [0] * delta_len + input_ids
        input_mask = [1] * delta_len + input_mask
        segment_ids = [SEG_ID_PAD] * delta_len + segment_ids

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    return input_ids, input_mask, segment_ids
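Note that, unlike BERT, both helpers follow the XLNet convention: <sep> and <cls> go at the end of the sequence, padding is added on the left, and input_mask is 1 for padding and 0 for real tokens. A minimal sketch of the resulting layout, assuming a toy tokenizer that just returns three fixed ids and max_seq_length=6 (not the real SentencePiece model):

# Toy illustration with a hypothetical tokenizer.
def toy_tokenize_fn(text):
    return [11, 12, 13]   # pretend these are the piece ids of `text`

ids, mask, segs = get_class_ids("某条中文文本", 6, toy_tokenize_fn)
print(ids)   # [0, 11, 12, 13, 4, 3]  -> left pad, tokens, SEP_ID, CLS_ID
print(mask)  # [1, 0, 0, 0, 0, 0]     -> 1 marks the left padding
print(segs)  # [4, 0, 0, 0, 0, 2]     -> SEG_ID_PAD, SEG_ID_A..., SEG_ID_CLS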
SPIECE_UNDERLINE = '▁'


def encode_pieces(sp_model, text, return_unicode=True, sample=False):
    if six.PY2 and isinstance(text, unicode):
        text = text.encode('utf-8')

    if not sample:
        pieces = sp_model.EncodeAsPieces(text)
    else:
        pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
    new_pieces = []
    for piece in pieces:
        if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
            cur_pieces = sp_model.EncodeAsPieces(
                piece[:-1].replace(SPIECE_UNDERLINE, ''))
            if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                if len(cur_pieces[0]) == 1:
                    cur_pieces = cur_pieces[1:]
                else:
                    cur_pieces[0] = cur_pieces[0][1:]
            cur_pieces.append(piece[-1])
            new_pieces.extend(cur_pieces)
        else:
            new_pieces.append(piece)

    # note(zhiliny): convert back to unicode for py2
    if six.PY2 and return_unicode:
        ret_pieces = []
        for piece in new_pieces:
            if isinstance(piece, str):
                piece = piece.decode('utf-8')
            ret_pieces.append(piece)
        new_pieces = ret_pieces

    return new_pieces
def encode_ids(sp_model, text, sample=False):
    pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
    ids = [sp_model.PieceToId(piece) for piece in pieces]
    return ids
def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
    if remove_space:
        outputs = ' '.join(inputs.strip().split())
    else:
        outputs = inputs
    outputs = outputs.replace("``", '"').replace("''", '"')

    if six.PY2 and isinstance(outputs, str):
        outputs = outputs.decode('utf-8')

    if not keep_accents:
        outputs = unicodedata.normalize('NFKD', outputs)
        outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
    if lower:
        outputs = outputs.lower()

    return outputs
def tokenize_fn(text):
    text = preprocess_text(text, lower=True)
    return encode_ids(sp, text)
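These preprocessing and SentencePiece helpers match the ones in the official XLNet repo; tokenize_fn turns raw text directly into piece ids. A quick sanity check, assuming FLAGS.spiece_model_file points at the pretrained Chinese spiece.model (the exact pieces and ids depend entirely on that model, so the printed values below are only illustrative):

# Illustrative only: actual output depends on the loaded SentencePiece model.
pieces = encode_pieces(sp, preprocess_text("今天天气不错", lower=True))
ids = tokenize_fn("今天天气不错")
print(pieces)  # e.g. ['▁', '今天', '天气', '不错']   (model-dependent)
print(ids)     # the corresponding piece ids for those pieces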
def get_vocab(path):
    maps = collections.defaultdict()
    i = 0
    with tf.gfile.GFile(path, "r") as f:
        for line in f.readlines():
            maps[line.strip()] = i
            i = i + 1
    return maps
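get_vocab expects a plain-text label vocabulary with one class name per line; the line index becomes the label id. For example, with a hypothetical labels.txt (the file name and labels here are placeholders):

# labels.txt, one class name per line:
#     体育
#     财经
#     科技
vocab = get_vocab("labels.txt")   # {'体育': 0, '财经': 1, '科技': 2}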
def writedataclass(inputpath, vocab, outputpath, max_seq_length, tokenize_fn):
    eachonum = 5000  # examples per TFRecord shard
    num = 0
    recordfilenum = 0
    ftrecordfilename = ("xlnetreading.tfrecords-%.3d" % recordfilenum)
    writer = tf.python_io.TFRecordWriter(outputpath + ftrecordfilename)
    with open(inputpath) as f:
        for text in f.readlines():
            # each line is "text<TAB>label"
            texts = text.split("\t")
            content = texts[0].lower().strip()
            label = vocab.get(texts[1].strip())
            num = num + 1
            input_ids, input_mask, segment_ids = get_class_ids(content, max_seq_length, tokenize_fn)
            # roll over to a new shard every `eachonum` examples
            if num > eachonum:
                num = 1
                recordfilenum = recordfilenum + 1
                ftrecordfilename = ("xlnetreading.tfrecords-%.3d" % recordfilenum)
                writer.close()  # close the previous shard before opening the next one
                writer = tf.python_io.TFRecordWriter(outputpath + ftrecordfilename)

            example = tf.train.Example(
                features=tf.train.Features(
                    feature={'input_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=input_ids)),
                             'input_mask': tf.train.Feature(int64_list=tf.train.Int64List(value=input_mask)),
                             'segment_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=segment_ids)),
                             'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
                             }))
            serialized = example.SerializeToString()
            writer.write(serialized)
    writer.close()
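To actually build the shards, the helpers above are wired together roughly like this (file names and the output directory are my own placeholders, adjust to your data layout):

# Hypothetical driver code: paths are placeholders.
vocab = get_vocab("labels.txt")             # label name -> id, one label per line
writedataclass("train.txt",                 # lines of "text<TAB>label"
               vocab,
               "./tfrecord/",               # output directory / filename prefix
               FLAGS.max_seq_length,
               tokenize_fn)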
I wrote a text classification class myself; take a look:
import time
import datetime
from datetime import timedelta

import tensorflow as tf

# these modules come from the official XLNet repo
import xlnet
import modeling
import model_utils

from textclass import FLAGS  # same flag definitions as in the conversion script above


class XlnetReadingClass(object):
    def __init__(self, model_config_path, is_training, FLAGS, input_ids, segment_ids,
                 input_mask, label, n_class):
        self.xlnet_config = xlnet.XLNetConfig(json_path=model_config_path)
        self.run_config = xlnet.create_run_config(is_training, True, FLAGS)
        # XLNetModel expects inputs with shape [seq_len, batch_size], so transpose
        self.input_ids = tf.transpose(input_ids, [1, 0])
        self.segment_ids = tf.transpose(segment_ids, [1, 0])
        self.input_mask = tf.transpose(input_mask, [1, 0])

        self.model = xlnet.XLNetModel(
            xlnet_config=self.xlnet_config,
            run_config=self.run_config,
            input_ids=self.input_ids,
            seg_ids=self.segment_ids,
            input_mask=self.input_mask)

        cls_scope = FLAGS.cls_scope
        summary = self.model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)
        self.per_example_loss, self.logits = modeling.classification_loss(
            hidden=summary,
            labels=label,
            n_class=n_class,
            initializer=self.model.get_initializer(),
            scope=cls_scope,
            return_logits=True)

        self.total_loss = tf.reduce_mean(self.per_example_loss)

        with tf.name_scope("train_op"):
            self.train_op, _, _ = model_utils.get_train_op(FLAGS, self.total_loss)

        with tf.name_scope("acc"):
            one_hot_target = tf.one_hot(label, n_class)
            self.acc = self.accuracy(self.logits, one_hot_target)

    def accuracy(self, logits, labels):
        arglabels_ = tf.argmax(tf.nn.softmax(logits), 1)
        arglabels = tf.argmax(tf.squeeze(labels), 1)
        acc = tf.to_float(tf.equal(arglabels_, arglabels))
        return tf.reduce_mean(acc)
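The training entry point below calls an inputs(...) function that is not shown in the post. Here is one possible sketch, written against the same TF 1.x queue-runner pipeline that main already assumes; the shuffle_batch capacities and the float cast of input_mask are my own choices, not taken from the original:

def inputs(files, batch_size, num_epochs, max_seq_length):
    """Minimal queue-based reader for the TFRecord shards written above (a sketch)."""
    filename_queue = tf.train.string_input_producer(files, num_epochs=num_epochs)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'input_ids': tf.FixedLenFeature([max_seq_length], tf.int64),
            'input_mask': tf.FixedLenFeature([max_seq_length], tf.int64),
            'segment_ids': tf.FixedLenFeature([max_seq_length], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    input_ids = tf.cast(features['input_ids'], tf.int32)
    input_mask = tf.cast(features['input_mask'], tf.float32)  # XLNet takes a float mask
    segment_ids = tf.cast(features['segment_ids'], tf.int32)
    label = tf.cast(features['label'], tf.int32)
    return tf.train.shuffle_batch(
        [input_ids, input_mask, segment_ids, label],
        batch_size=batch_size,
        capacity=10 * batch_size,
        min_after_dequeue=2 * batch_size)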
def main(_):
    print('Loading config...')

    n_class = 38

    input_path = FLAGS.data_dir + "xlnetreading.tfrecords*"

    print("input_path:", input_path)
    files = tf.train.match_filenames_once(input_path)

    # `inputs` (see the sketch above) reads the TFRecord shards and returns batched tensors
    input_ids, input_mask, segment_ids, label_ids = inputs(
        files, batch_size=FLAGS.batch_size, num_epochs=5, max_seq_length=FLAGS.max_seq_length)
    model_config_path = FLAGS.model_config_path
    is_training = False
    init_checkpoint = FLAGS.init_checkpoint

    model = XlnetReadingClass(model_config_path, is_training, FLAGS, input_ids,
                              segment_ids, input_mask, label_ids, n_class)

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    if init_checkpoint:
        (assignment_map, initialized_variable_names) = model_utils.get_assignment_map_from_checkpoint(
            tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        print("restore success on cpu or gpu")

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    session.run(tf.local_variables_initializer())

    print("**** Trainable Variables ****")
    for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
        print("name = {0}, shape = {1}{2}".format(var.name, var.shape, init_string))

    print("xlnet reading class model will start train .........")

    print(session.run(files))
    saver = tf.train.Saver()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=session)
    start_time = time.time()
    for i in range(8000):
        _, loss_train, acc = session.run([model.train_op, model.total_loss, model.acc])
        if i % 100 == 0:
            end_time = time.time()
            time_dif = end_time - start_time
            time_dif = timedelta(seconds=int(round(time_dif)))
            msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Cost: {2} Time: {3} acc: {4}'
            print(msg.format(i, loss_train, time_dif,
                             datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), acc))
            start_time = time.time()
        if i % 500 == 0 and i > 0:
            saver.save(session, "../exp/reading/model.ckpt", global_step=i)
    coord.request_stop()
    coord.join(threads)
    session.close()
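Since main(_) follows the standard TF 1.x app pattern, the script presumably ends with the usual entry point (not shown in the original post):

if __name__ == "__main__":
    tf.app.run()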