XLNet Chinese Text Classification Task

Converting the data to TFRecord files. The preprocessing below follows the official XLNet pipeline: SentencePiece tokenization, a trailing <sep> and <cls> token, an input mask where 1 marks padding, and left padding up to max_seq_length. A short usage sketch follows the code.

    # -*- coding: utf-8 -*-
    import collections
    import six
    import unicodedata

    import sentencepiece as spm
    import tensorflow as tf

    from textclass import FLAGS

    # Segment ids used by XLNet.
    SEG_ID_A = 0
    SEG_ID_B = 1
    SEG_ID_CLS = 2
    SEG_ID_SEP = 3
    SEG_ID_PAD = 4

    special_symbols = {
        "<unk>": 0,
        "<s>": 1,
        "</s>": 2,
        "<cls>": 3,
        "<sep>": 4,
        "<pad>": 5,
        "<mask>": 6,
        "<eod>": 7,
        "<eop>": 8,
    }

    VOCAB_SIZE = 32000
    UNK_ID = special_symbols["<unk>"]
    CLS_ID = special_symbols["<cls>"]
    SEP_ID = special_symbols["<sep>"]
    MASK_ID = special_symbols["<mask>"]
    EOD_ID = special_symbols["<eod>"]

    # Load the SentencePiece model shipped with the pretrained XLNet checkpoint.
    sp = spm.SentencePieceProcessor()
    sp.Load(FLAGS.spiece_model_file)


    def _truncate_seq_pair(tokens_a, tokens_b, max_length):
        """Truncates the longer sequence until the pair fits into max_length."""
        while True:
            total_length = len(tokens_a) + len(tokens_b)
            if total_length <= max_length:
                break
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()


    def get_class_ids(text, max_seq_length, tokenize_fn):
        """Converts a single text into XLNet input_ids / input_mask / segment_ids."""
        texts = tokenize_fn(text)
        if len(texts) > max_seq_length - 2:
            texts = texts[:max_seq_length - 2]
        tokens = []
        segment_ids = []
        for token in texts:
            tokens.append(token)
            segment_ids.append(SEG_ID_A)
        tokens.append(SEP_ID)
        segment_ids.append(SEG_ID_A)

        # XLNet places <cls> at the end of the sequence, not at the beginning.
        tokens.append(CLS_ID)
        segment_ids.append(SEG_ID_CLS)

        input_ids = tokens
        # In XLNet, 0 marks a real token and 1 marks padding (the opposite of BERT).
        input_mask = [0] * len(input_ids)
        # Pad on the left, as in the official preprocessing.
        if len(input_ids) < max_seq_length:
            delta_len = max_seq_length - len(input_ids)
            input_ids = [0] * delta_len + input_ids
            input_mask = [1] * delta_len + input_mask
            segment_ids = [SEG_ID_PAD] * delta_len + segment_ids

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        return input_ids, input_mask, segment_ids


    def get_pair_ids(text_a, text_b, max_seq_length, tokenize_fn):
        """Same as get_class_ids, but for sentence-pair tasks."""
        tokens_a = tokenize_fn(text_a)
        tokens_b = tokenize_fn(text_b)
        _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

        tokens = []
        segment_ids = []
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(SEG_ID_A)
        tokens.append(SEP_ID)
        segment_ids.append(SEG_ID_A)

        for token in tokens_b:
            tokens.append(token)
            segment_ids.append(SEG_ID_B)
        tokens.append(SEP_ID)
        segment_ids.append(SEG_ID_B)

        tokens.append(CLS_ID)
        segment_ids.append(SEG_ID_CLS)

        input_ids = tokens
        input_mask = [0] * len(input_ids)

        if len(input_ids) < max_seq_length:
            delta_len = max_seq_length - len(input_ids)
            input_ids = [0] * delta_len + input_ids
            input_mask = [1] * delta_len + input_mask
            segment_ids = [SEG_ID_PAD] * delta_len + segment_ids

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        return input_ids, input_mask, segment_ids


    SPIECE_UNDERLINE = '▁'


    def encode_pieces(sp_model, text, return_unicode=True, sample=False):
        """SentencePiece tokenization, taken from the official XLNet prepro_utils."""
        if six.PY2 and isinstance(text, unicode):
            text = text.encode('utf-8')

        if not sample:
            pieces = sp_model.EncodeAsPieces(text)
        else:
            pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        new_pieces = []
        for piece in pieces:
            if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
                cur_pieces = sp_model.EncodeAsPieces(
                    piece[:-1].replace(SPIECE_UNDERLINE, ''))
                if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                    if len(cur_pieces[0]) == 1:
                        cur_pieces = cur_pieces[1:]
                    else:
                        cur_pieces[0] = cur_pieces[0][1:]
                cur_pieces.append(piece[-1])
                new_pieces.extend(cur_pieces)
            else:
                new_pieces.append(piece)

        # note(zhiliny): convert back to unicode for py2
        if six.PY2 and return_unicode:
            ret_pieces = []
            for piece in new_pieces:
                if isinstance(piece, str):
                    piece = piece.decode('utf-8')
                ret_pieces.append(piece)
            new_pieces = ret_pieces

        return new_pieces


    def encode_ids(sp_model, text, sample=False):
        pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
        ids = [sp_model.PieceToId(piece) for piece in pieces]
        return ids


    def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
        if remove_space:
            outputs = ' '.join(inputs.strip().split())
        else:
            outputs = inputs
        outputs = outputs.replace("``", '"').replace("''", '"')

        if six.PY2 and isinstance(outputs, str):
            outputs = outputs.decode('utf-8')

        if not keep_accents:
            outputs = unicodedata.normalize('NFKD', outputs)
            outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
        if lower:
            outputs = outputs.lower()

        return outputs


    def tokenize_fn(text):
        text = preprocess_text(text, lower=True)
        return encode_ids(sp, text)


    def get_vocab(path):
        """Reads the label vocabulary (one label per line) into {label: index}."""
        maps = collections.defaultdict()
        i = 0
        with tf.gfile.GFile(path, "r") as f:
            for line in f.readlines():
                maps[line.strip()] = i
                i = i + 1
        return maps


    def writedataclass(inputpath, vocab, outputpath, max_seq_length, tokenize_fn):
        """Writes the labelled texts into sharded TFRecord files, 5000 examples per shard.

        Each input line is expected to be "content<TAB>label".
        """
        eachonum = 5000
        num = 0
        recordfilenum = 0
        ftrecordfilename = ("xlnetreading.tfrecords-%.3d" % recordfilenum)
        writer = tf.python_io.TFRecordWriter(outputpath + ftrecordfilename)
        with open(inputpath) as f:
            for text in f.readlines():
                texts = text.split("\t")
                content = texts[0].lower().strip()
                label = vocab.get(texts[1].strip())
                num = num + 1
                input_ids, input_mask, segment_ids = get_class_ids(content, max_seq_length, tokenize_fn)
                # Roll over to a new shard every `eachonum` examples.
                if num > eachonum:
                    num = 1
                    recordfilenum = recordfilenum + 1
                    ftrecordfilename = ("xlnetreading.tfrecords-%.3d" % recordfilenum)
                    writer.close()
                    writer = tf.python_io.TFRecordWriter(outputpath + ftrecordfilename)

                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={
                            'input_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=input_ids)),
                            'input_mask': tf.train.Feature(int64_list=tf.train.Int64List(value=input_mask)),
                            'segment_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=segment_ids)),
                            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                        }))
                serialized = example.SerializeToString()
                writer.write(serialized)
        writer.close()
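For reference, a minimal driver for the preprocessing code above might look like the following sketch. The label-vocabulary file, input file, output directory and max_seq_length are placeholder values, not given in the original post; each input line is assumed to be "content<TAB>label", matching what writedataclass expects.

    # Hypothetical usage of get_vocab / writedataclass; paths and lengths are examples only.
    label_vocab = get_vocab("data/labels.txt")      # one label name per line
    writedataclass(inputpath="data/train.txt",      # "content\tlabel" per line
                   vocab=label_vocab,
                   outputpath="data/tfrecords/",
                   max_seq_length=128,
                   tokenize_fn=tokenize_fn)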
     
I also wrote a text classification class myself; the training script is below. It builds the XLNetModel graph, puts a softmax classification head on the pooled summary, and trains from the TFRecords written above (a sketch of the missing inputs function follows the listing).
    # -*- coding: utf-8 -*-
    import datetime
    import time
    from datetime import timedelta

    import tensorflow as tf

    # xlnet.py, modeling.py and model_utils.py come from the official XLNet repo.
    # FLAGS is assumed to be the same flags object as in the preprocessing script;
    # adjust the import to your own project layout.
    import modeling
    import model_utils
    import xlnet
    from textclass import FLAGS


    class XlnetReadingClass(object):
        def __init__(self, model_config_path, is_training, FLAGS, input_ids,
                     segment_ids, input_mask, label, n_class):
            self.xlnet_config = xlnet.XLNetConfig(json_path=model_config_path)
            self.run_config = xlnet.create_run_config(is_training, True, FLAGS)
            # XLNet expects inputs of shape [seq_len, batch_size], so transpose.
            self.input_ids = tf.transpose(input_ids, [1, 0])
            self.segment_ids = tf.transpose(segment_ids, [1, 0])
            self.input_mask = tf.transpose(input_mask, [1, 0])

            self.model = xlnet.XLNetModel(
                xlnet_config=self.xlnet_config,
                run_config=self.run_config,
                input_ids=self.input_ids,
                seg_ids=self.segment_ids,
                input_mask=self.input_mask)

            cls_scope = FLAGS.cls_scope
            # Pooled summary of the sequence, fed into a softmax classification head.
            summary = self.model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)
            self.per_example_loss, self.logits = modeling.classification_loss(
                hidden=summary,
                labels=label,
                n_class=n_class,
                initializer=self.model.get_initializer(),
                scope=cls_scope,
                return_logits=True)

            self.total_loss = tf.reduce_mean(self.per_example_loss)

            with tf.name_scope("train_op"):
                self.train_op, _, _ = model_utils.get_train_op(FLAGS, self.total_loss)

            with tf.name_scope("acc"):
                one_hot_target = tf.one_hot(label, n_class)
                self.acc = self.accuracy(self.logits, one_hot_target)

        def accuracy(self, logits, labels):
            arglabels_ = tf.argmax(tf.nn.softmax(logits), 1)
            arglabels = tf.argmax(tf.squeeze(labels), 1)
            acc = tf.to_float(tf.equal(arglabels_, arglabels))
            return tf.reduce_mean(acc)


    def main(_):
        print('Loading config...')

        n_class = 38

        input_path = FLAGS.data_dir + "xlnetreading.tfrecords*"

        print("input_path:", input_path)
        files = tf.train.match_filenames_once(input_path)

        # `inputs` builds the TFRecord reading pipeline from the matched files
        # (see the sketch after this listing).
        input_ids, input_mask, segment_ids, label_ids = inputs(
            files, batch_size=FLAGS.batch_size, num_epochs=5,
            max_seq_length=FLAGS.max_seq_length)
        model_config_path = FLAGS.model_config_path
        is_training = False
        init_checkpoint = FLAGS.init_checkpoint

        model = XlnetReadingClass(model_config_path, is_training, FLAGS, input_ids,
                                  segment_ids, input_mask, label_ids, n_class)

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        if init_checkpoint:
            (assignment_map, initialized_variable_names) = \
                model_utils.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
            print("restore success on cpu or gpu")

        session = tf.Session()
        session.run(tf.global_variables_initializer())
        session.run(tf.local_variables_initializer())

        print("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            print("name = {0}, shape = {1}{2}".format(var.name, var.shape, init_string))

        print("xlnet reading class model will start training .........")

        print(session.run(files))
        saver = tf.train.Saver()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=session)
        start_time = time.time()
        for i in range(8000):
            _, loss_train, acc = session.run([model.train_op, model.total_loss, model.acc])
            if i % 100 == 0:
                end_time = time.time()
                time_dif = timedelta(seconds=int(round(end_time - start_time)))
                msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Cost: {2} Time: {3} acc: {4}'
                print(msg.format(i, loss_train, time_dif,
                                 datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), acc))
                start_time = time.time()
            if i % 500 == 0 and i > 0:
                saver.save(session, "../exp/reading/model.ckpt", global_step=i)
        coord.request_stop()
        coord.join(threads)
        session.close()
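The main function above depends on an inputs function that the post never shows. Below is a minimal sketch of what it could look like, written against the same TF 1.x queue-based pipeline the rest of the script uses (string_input_producer, shuffle_batch); the feature names and shapes match what writedataclass wrote, but the original implementation may differ. To actually launch training, the script also needs the usual tf.app.run() entry point, which the post omits.

    # Hedged sketch of the missing `inputs` function: parse the TFRecord features
    # written by writedataclass and return shuffled batches.
    def inputs(files, batch_size, num_epochs, max_seq_length):
        filename_queue = tf.train.string_input_producer(files, num_epochs=num_epochs)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'input_ids': tf.FixedLenFeature([max_seq_length], tf.int64),
                'input_mask': tf.FixedLenFeature([max_seq_length], tf.int64),
                'segment_ids': tf.FixedLenFeature([max_seq_length], tf.int64),
                'label': tf.FixedLenFeature([], tf.int64),
            })
        input_ids = tf.cast(features['input_ids'], tf.int32)
        input_mask = tf.cast(features['input_mask'], tf.float32)  # XLNet expects a float mask
        segment_ids = tf.cast(features['segment_ids'], tf.int32)
        label = tf.cast(features['label'], tf.int32)
        # Shuffle and batch; num_epochs on the filename queue requires the
        # local-variables initializer that main() already runs.
        return tf.train.shuffle_batch(
            [input_ids, input_mask, segment_ids, label],
            batch_size=batch_size,
            capacity=10000,
            min_after_dequeue=1000)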

 

posted on 2019-09-02 20:49 曹明