tensorflow版本的tansformer训练IWSLT数据集
代码来源:https://github.com/Kyubyong/transformer
1、git clone https://github.com/Kyubyong/transformer.git
2、pip install sentencepiece
3、下载数据集
进入到tansformer目录下,输入:sh download.sh
运行成功之后,会有这么一些文件:
de-en.de.xml中内容大致是这个样子的:
<?xml version="1.0" encoding="UTF-8"?> <mteval> <srcset setid="iwslt2016-dev2010" srclang="german"> <doc docid="69" genre="lectures"> <url>http://www.ted.com/talks/lang/de/wade_davis_on_endangered_cultures.html</url> <description>Mit atemberaubenden Fotos und Geschichten feiert der National Geographic- Forschungsreisende Wade Davis, die außergewöhnliche Vielfalt der Ureinwohner der Welt, welche in alarmierender Anzahl von unserem Planeten verschwinden.</description> <keywords>anthropology,culture,environment,film,global issues,language,photography</keywords> <talkid>69</talkid> <title>Wade Davis über gefährdete Kulturen</title> <seg id="1"> Wissen Sie, eines der großen Vernügen beim Reisen und eine der Freuden bei der ethnographischen Forschung ist, gemeinsam mit den Menschen zu leben, die sich noch an die alten Tage erinnern können. Die ihre Vergangenheit noch immer im Wind spüren, sie auf vom Regen geglätteten Steinen berühren, sie in den bitteren Blättern der Pflanzen schmecken. </seg> <seg id="2"> Einfach das Wissen, dass Jaguar-Schamanen noch immer jenseits der Milchstraße reisen oder die Bedeutung der Mythen der Ältesten der Inuit noch voller Bedeutung sind, oder dass im Himalaya die Buddhisten noch immer den Atem des Dharma verfolgen, bedeutet, sich die zentrale Offenbarung der Anthropologie ins Gedächtnis zu rufen, das ist der Gedanke, dass die Welt, in der wir leben, nicht in einem absoluten Sinn existiert, sondern nur als ein Modell der Realität, als eine Folge einer Gruppe von bestimmten Möglichkeiten der Anpassung die unsere Ahnen, wenngleich erfolgreich, vor vielen Generationen wählten. </seg> <seg id="3"> Und natürlich teilen wir alle dieselben Anpassungsnotwendigkeiten. </seg> <seg id="4"> Wir werden alle geboren. Wir bringen Kinder zur Welt. </seg> <seg id="5"> Wir durchlaufen Initiationsrituale. </seg> <seg id="6"> Wir müssen uns mit der unaufhaltsamen Trennung durch den Tod auseinandersetzen und somit sollte es uns nicht überraschen, dass wir alle singen, tanzen und und Kunst hervorbringen. </seg> <seg id="7"> Aber interessant ist der einzigartige Tonfall des Liedes, der Rhythmus des Tanzes in jeder Kultur. </seg> <seg id="8"> Dabei spielt es keine Rolle, ob es sich um die Penan in den Wäldern von Borneo handelt, oder die Voodoo-Akolythen in Haiti, oder die Krieger in der Kaisut-Wüste von Nordkenia, die Curanderos in den Anden, oder eine Karawanserei mitten in der Sahara. Dies ist zufällig der Kollege, mit dem ich vor einem Monat in die Wüste gereist bin. Oder selbst ein Yak-Hirte an den Hängen des Qomolangma, Everest, der Gottmutter der Welt. </seg> <seg id="9"> All diese Menschen lehren uns, dass es noch andere Existenzmöglichkeiten, andere Denkweisen, andere Wege zur Orientierung auf der Erde gibt. </seg> <seg id="10"> Und das ist eine Vorstellung, die, wenn man darüber nachdenkt, einen nur mit Hoffnung erfüllen kann. </seg> <seg id="11"> Zusammen bilden die unzähligen Kulturen der Welt ein Netz aus spirituellem und kulturellem Leben, das die Erde umhüllt und für das Wohl der Erde genauso wichtig ist, wie das biologische Lebensnetz, das man als Biosphäre kennt. </seg> <seg id="12"> Man kann sich dieses kulturelle Lebensnetz als eine Ethnosphäre vorstellen. Ethnosphäre kann dabei als die Gesamtsumme aller Gedanken und Träume, Mythen Ideen, Inspirationen und Intuitionen, die von der menschlichen Vorstellungskraft seit den Anfängen des Bewusstseins hervorgebracht wurden, definiert werden. </seg> <seg id="13"> Die Ethnosphäre ist das großartige Vermächtnis der Menschheit. </seg> <seg id="14"> Sie ist das Symbol all dessen, was wir sind und wozu wir als erstaunlich wissbegierige Spezies fähig sind. </seg> <seg id="15"> Und genauso wie die Biosphäre stark abgetragen wurde, geschah dies mit der Ethnosphäre -- nur mit noch größerer Geschwindigkeit. </seg> <seg id="16"> Kein Biologe würde zum Beispiel wagen zu behaupten, dass 50% oder mehr aller Arten kurz vor dem Aussterben sind, da es einfach nicht stimmt. Und doch, dieses -- das apokalyptischste Szenarium auf dem Gebiet der biologischen Vielfalt -- entspricht kaum dem, was uns als optimistischstes Szenarium auf dem Gebiet der kulturellen Vielfalt bekannt ist. </seg> <seg id="17"> Und der entscheidende Indikator dafür ist das Aussterben der Sprachen. </seg>
4、创建训练集、验证集、测试集
python prepro.py --vocab_size 8000
部分运行结果:
trainer_interface.cc(615) LOG(INFO) Saving model: iwslt2016/segmented/bpe.model trainer_interface.cc(626) LOG(INFO) Saving vocabs: iwslt2016/segmented/bpe.vocab INFO:root:# Load trained bpe model INFO:root:# Segment INFO:root:Let's see how segmented data look like train1: ▁David ▁G all o : ▁Das ▁ist ▁Bill ▁L ange . ▁Ich ▁bin ▁Da ve ▁G all o . train2: ▁David ▁G all o : ▁This ▁is ▁Bill ▁L ange . ▁I ' m ▁Da ve ▁G all o . eval1: ▁Als ▁ich ▁11 ▁Jahre ▁alt ▁war , ▁wurde ▁ich ▁eines ▁Morgen s ▁von ▁den ▁Kl ängen ▁h eller ▁Freude ▁ge we ckt . eval2: ▁When ▁I ▁was ▁11 , ▁I ▁remember ▁w aking ▁up ▁one ▁morning ▁to ▁the ▁sound ▁of ▁j oy ▁in ▁my ▁house . test1: ▁Als ▁ich ▁in ▁meinen ▁20 ern ▁war , ▁hatte ▁ich ▁meine ▁erste ▁Psych other ap ie - P at ient in . INFO:root:Done
运行之后会有:
prepro.py中的内容如下:
# -*- coding: utf-8 -*- #/usr/bin/python3 ''' Feb. 2019 by kyubyong park. kbpark.linguist@gmail.com. https://www.github.com/kyubyong/transformer. Preprocess the iwslt 2016 datasets. ''' import os import errno import sentencepiece as spm import re from hparams import Hparams import logging logging.basicConfig(level=logging.INFO) def prepro(hp): """Load raw data -> Preprocessing -> Segmenting with sentencepice hp: hyperparams. argparse. """ logging.info("# Check if raw files exist") train1 = "iwslt2016/de-en/train.tags.de-en.de" train2 = "iwslt2016/de-en/train.tags.de-en.en" eval1 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.de.xml" eval2 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.en.xml" test1 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.de.xml" test2 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.en.xml" for f in (train1, train2, eval1, eval2, test1, test2): if not os.path.isfile(f): raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), f) logging.info("# Preprocessing") # train _prepro = lambda x: [line.strip() for line in open(x, 'r').read().split("\n") \ if not line.startswith("<")] prepro_train1, prepro_train2 = _prepro(train1), _prepro(train2) assert len(prepro_train1)==len(prepro_train2), "Check if train source and target files match." # eval _prepro = lambda x: [re.sub("<[^>]+>", "", line).strip() \ for line in open(x, 'r').read().split("\n") \ if line.startswith("<seg id")] prepro_eval1, prepro_eval2 = _prepro(eval1), _prepro(eval2) assert len(prepro_eval1) == len(prepro_eval2), "Check if eval source and target files match." # test prepro_test1, prepro_test2 = _prepro(test1), _prepro(test2) assert len(prepro_test1) == len(prepro_test2), "Check if test source and target files match." logging.info("Let's see how preprocessed data look like") logging.info("prepro_train1:", prepro_train1[0]) logging.info("prepro_train2:", prepro_train2[0]) logging.info("prepro_eval1:", prepro_eval1[0]) logging.info("prepro_eval2:", prepro_eval2[0]) logging.info("prepro_test1:", prepro_test1[0]) logging.info("prepro_test2:", prepro_test2[0]) logging.info("# write preprocessed files to disk") os.makedirs("iwslt2016/prepro", exist_ok=True) def _write(sents, fname): with open(fname, 'w') as fout: fout.write("\n".join(sents)) _write(prepro_train1, "iwslt2016/prepro/train.de") _write(prepro_train2, "iwslt2016/prepro/train.en") _write(prepro_train1+prepro_train2, "iwslt2016/prepro/train") _write(prepro_eval1, "iwslt2016/prepro/eval.de") _write(prepro_eval2, "iwslt2016/prepro/eval.en") _write(prepro_test1, "iwslt2016/prepro/test.de") _write(prepro_test2, "iwslt2016/prepro/test.en") logging.info("# Train a joint BPE model with sentencepiece") os.makedirs("iwslt2016/segmented", exist_ok=True) train = '--input=iwslt2016/prepro/train --pad_id=0 --unk_id=1 \ --bos_id=2 --eos_id=3\ --model_prefix=iwslt2016/segmented/bpe --vocab_size={} \ --model_type=bpe'.format(hp.vocab_size) spm.SentencePieceTrainer.Train(train) logging.info("# Load trained bpe model") sp = spm.SentencePieceProcessor() sp.Load("iwslt2016/segmented/bpe.model") logging.info("# Segment") def _segment_and_write(sents, fname): with open(fname, "w") as fout: for sent in sents: pieces = sp.EncodeAsPieces(sent) fout.write(" ".join(pieces) + "\n") _segment_and_write(prepro_train1, "iwslt2016/segmented/train.de.bpe") _segment_and_write(prepro_train2, "iwslt2016/segmented/train.en.bpe") _segment_and_write(prepro_eval1, "iwslt2016/segmented/eval.de.bpe") _segment_and_write(prepro_eval2, "iwslt2016/segmented/eval.en.bpe") _segment_and_write(prepro_test1, "iwslt2016/segmented/test.de.bpe") logging.info("Let's see how segmented data look like") print("train1:", open("iwslt2016/segmented/train.de.bpe",'r').readline()) print("train2:", open("iwslt2016/segmented/train.en.bpe", 'r').readline()) print("eval1:", open("iwslt2016/segmented/eval.de.bpe", 'r').readline()) print("eval2:", open("iwslt2016/segmented/eval.en.bpe", 'r').readline()) print("test1:", open("iwslt2016/segmented/test.de.bpe", 'r').readline()) if __name__ == '__main__': hparams = Hparams() parser = hparams.parser hp = parser.parse_args() prepro(hp) logging.info("Done")
train中部分内容如下:
David Gallo: Das ist Bill Lange. Ich bin Dave Gallo. Wir werden Ihnen einige Geschichten über das Meer in Videoform erzählen. Wir haben ein paar der unglaublichsten Aufnahmen der Titanic, die man je gesehen hat,, und wir werden Ihnen nichts davon zeigen. Die Wahrheit ist, dass die Titanic – obwohl sie alle Kinokassenrekorde bricht – nicht gerade die aufregendste Geschichte vom Meer ist. Ich denke, das Problem ist, dass wir das Meer für zu selbstverständlich halten. Wenn man darüber nachdenkt, machen die Ozeane 75 % des Planeten aus. Der Großteil der Erde ist Meerwasser.
train.en.bpe中部分内容如下:
▁David ▁G all o : ▁This ▁is ▁Bill ▁L ange . ▁I ' m ▁Da ve ▁G all o . ▁And ▁we ' re ▁going ▁to ▁tell ▁you ▁some ▁stories ▁from ▁the ▁sea ▁here ▁in ▁video . ▁We ' ve ▁got ▁some ▁of ▁the ▁most ▁incredible ▁video ▁of ▁Tit an ic ▁that ' s ▁ever ▁been ▁seen , ▁and ▁we ' re ▁not ▁going ▁to ▁show ▁you ▁any ▁of ▁it . ▁The ▁truth ▁of ▁the ▁matter ▁is ▁that ▁the ▁Tit an ic ▁-- ▁even ▁though ▁it ' s ▁break ing ▁all ▁sorts ▁of ▁box ▁office ▁record s ▁-- ▁it ' s ▁not ▁the ▁most ▁exciting ▁story ▁from ▁the ▁sea . ▁And ▁the ▁problem , ▁I ▁think , ▁is ▁that ▁we ▁take ▁the ▁ocean ▁for ▁gr anted . ▁When ▁you ▁think ▁about ▁it , ▁the ▁oce ans ▁are ▁75 ▁percent ▁of ▁the ▁planet . ▁Most ▁of ▁the ▁planet ▁is ▁ocean ▁water .
bpe.vocab部分内容如下:
<pad> 0 <unk> 0 <s> 0 </s> 0 en -0 er -1 in -2 ▁t -3 ch -4 ▁a -5 ▁d -6 ▁w -7 ▁s -8 ▁th -9 nd -10 ie -11 es -12
5、train.py
# -*- coding: utf-8 -*- #/usr/bin/python3 ''' Feb. 2019 by kyubyong park. kbpark.linguist@gmail.com. https://www.github.com/kyubyong/transformer ''' import tensorflow as tf from model import Transformer from tqdm import tqdm from data_load import get_batch from utils import save_hparams, save_variable_specs, get_hypotheses, calc_bleu import os from hparams import Hparams import math import logging logging.basicConfig(level=logging.INFO) logging.info("# hparams") hparams = Hparams() parser = hparams.parser hp = parser.parse_args() save_hparams(hp, hp.logdir) logging.info("# Prepare train/eval batches") train_batches, num_train_batches, num_train_samples = get_batch(hp.train1, hp.train2, hp.maxlen1, hp.maxlen2, hp.vocab, hp.batch_size, shuffle=True) eval_batches, num_eval_batches, num_eval_samples = get_batch(hp.eval1, hp.eval2, 100000, 100000, hp.vocab, hp.batch_size, shuffle=False) # create a iterator of the correct shape and type iter = tf.data.Iterator.from_structure(train_batches.output_types, train_batches.output_shapes) xs, ys = iter.get_next() train_init_op = iter.make_initializer(train_batches) eval_init_op = iter.make_initializer(eval_batches) logging.info("# Load model") m = Transformer(hp) loss, train_op, global_step, train_summaries = m.train(xs, ys) y_hat, eval_summaries = m.eval(xs, ys) # y_hat = m.infer(xs, ys) logging.info("# Session") saver = tf.train.Saver(max_to_keep=hp.num_epochs) with tf.Session() as sess: ckpt = tf.train.latest_checkpoint(hp.logdir) if ckpt is None: logging.info("Initializing from scratch") sess.run(tf.global_variables_initializer()) save_variable_specs(os.path.join(hp.logdir, "specs")) else: saver.restore(sess, ckpt) summary_writer = tf.summary.FileWriter(hp.logdir, sess.graph) sess.run(train_init_op) total_steps = hp.num_epochs * num_train_batches _gs = sess.run(global_step) for i in tqdm(range(_gs, total_steps+1)): _, _gs, _summary = sess.run([train_op, global_step, train_summaries]) epoch = math.ceil(_gs / num_train_batches) summary_writer.add_summary(_summary, _gs) if _gs and _gs % num_train_batches == 0: logging.info("epoch {} is done".format(epoch)) _loss = sess.run(loss) # train loss logging.info("# test evaluation") _, _eval_summaries = sess.run([eval_init_op, eval_summaries]) summary_writer.add_summary(_eval_summaries, _gs) logging.info("# get hypotheses") hypotheses = get_hypotheses(num_eval_batches, num_eval_samples, sess, y_hat, m.idx2token) logging.info("# write results") model_output = "iwslt2016_E%02dL%.2f" % (epoch, _loss) if not os.path.exists(hp.evaldir): os.makedirs(hp.evaldir) translation = os.path.join(hp.evaldir, model_output) with open(translation, 'w') as fout: fout.write("\n".join(hypotheses)) logging.info("# calc bleu score and append it to translation") calc_bleu(hp.eval3, translation) logging.info("# save models") ckpt_name = os.path.join(hp.logdir, model_output) saver.save(sess, ckpt_name, global_step=_gs) logging.info("after training of {} epochs, {} has been saved.".format(epoch, ckpt_name)) logging.info("# fall back to train mode") sess.run(train_init_op) summary_writer.close() logging.info("Done")
我们一行行来看:
首先调用了hparams.py中的函数:
import argparse class Hparams: parser = argparse.ArgumentParser() # prepro parser.add_argument('--vocab_size', default=32000, type=int) # train ## files parser.add_argument('--train1', default='iwslt2016/segmented/train.de.bpe', help="german training segmented data") parser.add_argument('--train2', default='iwslt2016/segmented/train.en.bpe', help="english training segmented data") parser.add_argument('--eval1', default='iwslt2016/segmented/eval.de.bpe', help="german evaluation segmented data") parser.add_argument('--eval2', default='iwslt2016/segmented/eval.en.bpe', help="english evaluation segmented data") parser.add_argument('--eval3', default='iwslt2016/prepro/eval.en', help="english evaluation unsegmented data") ## vocabulary parser.add_argument('--vocab', default='iwslt2016/segmented/bpe.vocab', help="vocabulary file path") # training scheme parser.add_argument('--batch_size', default=128, type=int) parser.add_argument('--eval_batch_size', default=128, type=int) parser.add_argument('--lr', default=0.0003, type=float, help="learning rate") parser.add_argument('--warmup_steps', default=4000, type=int) parser.add_argument('--logdir', default="log/1", help="log directory") parser.add_argument('--num_epochs', default=20, type=int) parser.add_argument('--evaldir', default="eval/1", help="evaluation dir") # model parser.add_argument('--d_model', default=512, type=int, help="hidden dimension of encoder/decoder") parser.add_argument('--d_ff', default=2048, type=int, help="hidden dimension of feedforward layer") parser.add_argument('--num_blocks', default=6, type=int, help="number of encoder/decoder blocks") parser.add_argument('--num_heads', default=8, type=int, help="number of attention heads") parser.add_argument('--maxlen1', default=100, type=int, help="maximum length of a source sequence") parser.add_argument('--maxlen2', default=100, type=int, help="maximum length of a target sequence") parser.add_argument('--dropout_rate', default=0.3, type=float) parser.add_argument('--smoothing', default=0.1, type=float, help="label smoothing rate") # test parser.add_argument('--test1', default='iwslt2016/segmented/test.de.bpe', help="german test segmented data") parser.add_argument('--test2', default='iwslt2016/prepro/test.en', help="english test data") parser.add_argument('--ckpt', help="checkpoint file path") parser.add_argument('--test_batch_size', default=128, type=int) parser.add_argument('--testdir', default="test/1", help="test result dir")
主要是一些超参数的设置。
然后是data_load.py中用来加载数据集:
# -*- coding: utf-8 -*- #/usr/bin/python3 ''' Feb. 2019 by kyubyong park. kbpark.linguist@gmail.com. https://www.github.com/kyubyong/transformer Note. if safe, entities on the source side have the prefix 1, and the target side 2, for convenience. For example, fpath1, fpath2 means source file path and target file path, respectively. ''' import tensorflow as tf from utils import calc_num_batches def load_vocab(vocab_fpath): '''Loads vocabulary file and returns idx<->token maps vocab_fpath: string. vocabulary file path. Note that these are reserved 0: <pad>, 1: <unk>, 2: <s>, 3: </s> Returns two dictionaries. ''' vocab = [line.split()[0] for line in open(vocab_fpath, 'r').read().splitlines()] token2idx = {token: idx for idx, token in enumerate(vocab)} idx2token = {idx: token for idx, token in enumerate(vocab)} return token2idx, idx2token def load_data(fpath1, fpath2, maxlen1, maxlen2): '''Loads source and target data and filters out too lengthy samples. fpath1: source file path. string. fpath2: target file path. string. maxlen1: source sent maximum length. scalar. maxlen2: target sent maximum length. scalar. Returns sents1: list of source sents sents2: list of target sents ''' sents1, sents2 = [], [] with open(fpath1, 'r') as f1, open(fpath2, 'r') as f2: for sent1, sent2 in zip(f1, f2): if len(sent1.split()) + 1 > maxlen1: continue # 1: </s> if len(sent2.split()) + 1 > maxlen2: continue # 1: </s> sents1.append(sent1.strip()) sents2.append(sent2.strip()) return sents1, sents2 def encode(inp, type, dict): '''Converts string to number. Used for `generator_fn`. inp: 1d byte array. type: "x" (source side) or "y" (target side) dict: token2idx dictionary Returns list of numbers ''' inp_str = inp.decode("utf-8") if type=="x": tokens = inp_str.split() + ["</s>"] else: tokens = ["<s>"] + inp_str.split() + ["</s>"] x = [dict.get(t, dict["<unk>"]) for t in tokens] return x def generator_fn(sents1, sents2, vocab_fpath): '''Generates training / evaluation data sents1: list of source sents sents2: list of target sents vocab_fpath: string. vocabulary file path. yields xs: tuple of x: list of source token ids in a sent x_seqlen: int. sequence length of x sent1: str. raw source (=input) sentence labels: tuple of decoder_input: decoder_input: list of encoded decoder inputs y: list of target token ids in a sent y_seqlen: int. sequence length of y sent2: str. target sentence ''' token2idx, _ = load_vocab(vocab_fpath) for sent1, sent2 in zip(sents1, sents2): x = encode(sent1, "x", token2idx) y = encode(sent2, "y", token2idx) decoder_input, y = y[:-1], y[1:] x_seqlen, y_seqlen = len(x), len(y) yield (x, x_seqlen, sent1), (decoder_input, y, y_seqlen, sent2) def input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=False): '''Batchify data sents1: list of source sents sents2: list of target sents vocab_fpath: string. vocabulary file path. batch_size: scalar shuffle: boolean Returns xs: tuple of x: int32 tensor. (N, T1) x_seqlens: int32 tensor. (N,) sents1: str tensor. (N,) ys: tuple of decoder_input: int32 tensor. (N, T2) y: int32 tensor. (N, T2) y_seqlen: int32 tensor. (N, ) sents2: str tensor. (N,) ''' shapes = (([None], (), ()), ([None], [None], (), ())) types = ((tf.int32, tf.int32, tf.string), (tf.int32, tf.int32, tf.int32, tf.string)) paddings = ((0, 0, ''), (0, 0, 0, '')) dataset = tf.data.Dataset.from_generator( generator_fn, output_shapes=shapes, output_types=types, args=(sents1, sents2, vocab_fpath)) # <- arguments for generator_fn. converted to np string arrays if shuffle: # for training dataset = dataset.shuffle(128*batch_size) dataset = dataset.repeat() # iterate forever dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1) return dataset def get_batch(fpath1, fpath2, maxlen1, maxlen2, vocab_fpath, batch_size, shuffle=False): '''Gets training / evaluation mini-batches fpath1: source file path. string. fpath2: target file path. string. maxlen1: source sent maximum length. scalar. maxlen2: target sent maximum length. scalar. vocab_fpath: string. vocabulary file path. batch_size: scalar shuffle: boolean Returns batches num_batches: number of mini-batches num_samples ''' sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2) batches = input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=shuffle) num_batches = calc_num_batches(len(sents1), batch_size) return batches, num_batches, len(sents1)
6、看一下相关模型model.py
# -*- coding: utf-8 -*- # /usr/bin/python3 ''' Feb. 2019 by kyubyong park. kbpark.linguist@gmail.com. https://www.github.com/kyubyong/transformer Transformer network ''' import tensorflow as tf from data_load import load_vocab from modules import get_token_embeddings, ff, positional_encoding, multihead_attention, label_smoothing, noam_scheme from utils import convert_idx_to_token_tensor from tqdm import tqdm import logging logging.basicConfig(level=logging.INFO) class Transformer: ''' xs: tuple of x: int32 tensor. (N, T1) x_seqlens: int32 tensor. (N,) sents1: str tensor. (N,) ys: tuple of decoder_input: int32 tensor. (N, T2) y: int32 tensor. (N, T2) y_seqlen: int32 tensor. (N, ) sents2: str tensor. (N,) training: boolean. ''' def __init__(self, hp): self.hp = hp self.token2idx, self.idx2token = load_vocab(hp.vocab) self.embeddings = get_token_embeddings(self.hp.vocab_size, self.hp.d_model, zero_pad=True) def encode(self, xs, training=True): ''' Returns memory: encoder outputs. (N, T1, d_model) ''' with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE): x, seqlens, sents1 = xs # src_masks src_masks = tf.math.equal(x, 0) # (N, T1) # embedding enc = tf.nn.embedding_lookup(self.embeddings, x) # (N, T1, d_model) enc *= self.hp.d_model**0.5 # scale enc += positional_encoding(enc, self.hp.maxlen1) enc = tf.layers.dropout(enc, self.hp.dropout_rate, training=training) ## Blocks for i in range(self.hp.num_blocks): with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE): # self-attention enc = multihead_attention(queries=enc, keys=enc, values=enc, key_masks=src_masks, num_heads=self.hp.num_heads, dropout_rate=self.hp.dropout_rate, training=training, causality=False) # feed forward enc = ff(enc, num_units=[self.hp.d_ff, self.hp.d_model]) memory = enc return memory, sents1, src_masks def decode(self, ys, memory, src_masks, training=True): ''' memory: encoder outputs. (N, T1, d_model) src_masks: (N, T1) Returns logits: (N, T2, V). float32. y_hat: (N, T2). int32 y: (N, T2). int32 sents2: (N,). string. ''' with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE): decoder_inputs, y, seqlens, sents2 = ys # tgt_masks tgt_masks = tf.math.equal(decoder_inputs, 0) # (N, T2) # embedding dec = tf.nn.embedding_lookup(self.embeddings, decoder_inputs) # (N, T2, d_model) dec *= self.hp.d_model ** 0.5 # scale dec += positional_encoding(dec, self.hp.maxlen2) dec = tf.layers.dropout(dec, self.hp.dropout_rate, training=training) # Blocks for i in range(self.hp.num_blocks): with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE): # Masked self-attention (Note that causality is True at this time) dec = multihead_attention(queries=dec, keys=dec, values=dec, key_masks=tgt_masks, num_heads=self.hp.num_heads, dropout_rate=self.hp.dropout_rate, training=training, causality=True, scope="self_attention") # Vanilla attention dec = multihead_attention(queries=dec, keys=memory, values=memory, key_masks=src_masks, num_heads=self.hp.num_heads, dropout_rate=self.hp.dropout_rate, training=training, causality=False, scope="vanilla_attention") ### Feed Forward dec = ff(dec, num_units=[self.hp.d_ff, self.hp.d_model]) # Final linear projection (embedding weights are shared) weights = tf.transpose(self.embeddings) # (d_model, vocab_size) logits = tf.einsum('ntd,dk->ntk', dec, weights) # (N, T2, vocab_size) y_hat = tf.to_int32(tf.argmax(logits, axis=-1)) return logits, y_hat, y, sents2 def train(self, xs, ys): ''' Returns loss: scalar. train_op: training operation global_step: scalar. summaries: training summary node ''' # forward memory, sents1, src_masks = self.encode(xs) logits, preds, y, sents2 = self.decode(ys, memory, src_masks) # train scheme y_ = label_smoothing(tf.one_hot(y, depth=self.hp.vocab_size)) ce = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y_) nonpadding = tf.to_float(tf.not_equal(y, self.token2idx["<pad>"])) # 0: <pad> loss = tf.reduce_sum(ce * nonpadding) / (tf.reduce_sum(nonpadding) + 1e-7) global_step = tf.train.get_or_create_global_step() lr = noam_scheme(self.hp.lr, global_step, self.hp.warmup_steps) optimizer = tf.train.AdamOptimizer(lr) train_op = optimizer.minimize(loss, global_step=global_step) tf.summary.scalar('lr', lr) tf.summary.scalar("loss", loss) tf.summary.scalar("global_step", global_step) summaries = tf.summary.merge_all() return loss, train_op, global_step, summaries def eval(self, xs, ys): '''Predicts autoregressively At inference, input ys is ignored. Returns y_hat: (N, T2) ''' decoder_inputs, y, y_seqlen, sents2 = ys decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.token2idx["<s>"] ys = (decoder_inputs, y, y_seqlen, sents2) memory, sents1, src_masks = self.encode(xs, False) logging.info("Inference graph is being built. Please be patient.") for _ in tqdm(range(self.hp.maxlen2)): logits, y_hat, y, sents2 = self.decode(ys, memory, src_masks, False) if tf.reduce_sum(y_hat, 1) == self.token2idx["<pad>"]: break _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1) ys = (_decoder_inputs, y, y_seqlen, sents2) # monitor a random sample n = tf.random_uniform((), 0, tf.shape(y_hat)[0]-1, tf.int32) sent1 = sents1[n] pred = convert_idx_to_token_tensor(y_hat[n], self.idx2token) sent2 = sents2[n] tf.summary.text("sent1", sent1) tf.summary.text("pred", pred) tf.summary.text("sent2", sent2) summaries = tf.summary.merge_all() return y_hat, summaries