关于 LayoutLM 中的 FunsdDataset
在使用 Hugging Face 编写 LayoutLM 在 FUNSD 数据集上微调的任务时,用到了 layoutlm.data.funsd 中的 FunsdDataset 和 InputFeatures 两个类。
微软开源代码地址:unilm/funsd.py at master · microsoft/unilm (github.com)
Hugging Face 示例所引用的 LayoutLM 代码的开源地址:unilm/funsd.py at master · NielsRogge/unilm (github.com)
以下是加上理解性注释后的代码:
import logging
import os

import torch
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)


class FunsdDataset(Dataset):
    """FUNSD token-classification dataset backed by a list of InputFeatures.

    Each item is a tuple ``(input_ids, input_mask, segment_ids, label_ids,
    bboxes)`` of ``torch.long`` tensors, padded along the sequence to
    ``args.max_seq_length``. Features are built from ``{mode}.txt``,
    ``{mode}_box.txt`` and ``{mode}_image.txt`` in ``args.data_dir`` and are
    cached in that same directory.
    """

    def __init__(self, args, tokenizer, labels, pad_token_label_id, mode):
        """Load features from the cache file, or build them and write the cache.

        Args:
            args: run configuration; reads ``local_rank``, ``data_dir``,
                ``model_name_or_path``, ``max_seq_length``, ``overwrite_cache``
                and ``model_type``.
            tokenizer: provides ``tokenize``, ``convert_tokens_to_ids`` and the
                ``cls_token`` / ``sep_token`` / ``pad_token`` attributes.
            labels: list of all label strings; a label's index is its id.
            pad_token_label_id: label id given to sub-word continuations,
                special tokens and padding.
            mode: split name (e.g. ``"train"`` / ``"test"``); selects the files.
        """
        # Distributed training: only the first process builds the features;
        # the others wait here and then read the cache it wrote.
        if args.local_rank not in [-1, 0] and mode == "train":
            torch.distributed.barrier()

        # Load data features from cache or dataset file. The cache name encodes
        # the split, the last path component of the model name and the length.
        cached_features_file = os.path.join(
            args.data_dir,
            "cached_{}_{}_{}".format(
                mode,
                list(filter(None, args.model_name_or_path.split("/"))).pop(),
                str(args.max_seq_length),
            ),
        )
        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s", cached_features_file)
            # The cache holds pickled InputFeatures objects, which torch >= 2.6
            # rejects under its default weights_only=True. Only load caches this
            # code wrote: unpickling runs arbitrary code.
            try:
                features = torch.load(cached_features_file, weights_only=False)
            except TypeError:  # torch < 1.13 has no weights_only argument
                features = torch.load(cached_features_file)
        else:
            logger.info("Creating features from dataset file at %s", args.data_dir)
            # Read word-level examples from the text files, then turn them into
            # fixed-length, token-level features.
            examples = read_examples_from_file(args.data_dir, mode)
            features = convert_examples_to_features(
                examples,  # one InputExample per document
                labels,  # all label strings
                args.max_seq_length,  # maximum sequence length
                tokenizer,
                # xlnet has a cls token at the end
                cls_token_at_end=bool(args.model_type in ["xlnet"]),
                cls_token=tokenizer.cls_token,
                cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0,
                sep_token=tokenizer.sep_token,
                # roberta uses an extra separator b/w pairs of sentences,
                # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
                sep_token_extra=bool(args.model_type in ["roberta"]),
                # pad on the left for xlnet
                pad_on_left=bool(args.model_type in ["xlnet"]),
                pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
                pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
                pad_token_label_id=pad_token_label_id,
            )
            # Non-distributed runs (local_rank == -1) and the first distributed
            # process (local_rank == 0) write the cache.
            if args.local_rank in [-1, 0]:
                logger.info("Saving features into cached file %s", cached_features_file)
                torch.save(features, cached_features_file)

        # Release the processes waiting at the first barrier.
        if args.local_rank == 0 and mode == "train":
            torch.distributed.barrier()

        self.features = features
        # Convert to Tensors and build dataset
        self.all_input_ids = torch.tensor(
            [f.input_ids for f in features], dtype=torch.long
        )
        self.all_input_mask = torch.tensor(
            [f.input_mask for f in features], dtype=torch.long
        )
        self.all_segment_ids = torch.tensor(
            [f.segment_ids for f in features], dtype=torch.long
        )
        self.all_label_ids = torch.tensor(
            [f.label_ids for f in features], dtype=torch.long
        )
        self.all_bboxes = torch.tensor([f.boxes for f in features], dtype=torch.long)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        return (
            self.all_input_ids[index],
            self.all_input_mask[index],
            self.all_segment_ids[index],
            self.all_label_ids[index],
            self.all_bboxes[index],
        )


class InputExample(object):
    """A single training/test example for token classification: one document image."""

    def __init__(self, guid, words, labels, boxes, actual_bboxes, file_name, page_size):
        """Constructs an InputExample.

        Args:
            guid: Unique id for the example (``"{mode}-{index}"`` when read from file).
            words: list. The words of the sequence.
            labels: list. The label of each word; ``"O"`` for words whose
                line in the text file carries no label column.
            boxes: list. Per-word boxes scaled to the 0-1000 range.
            actual_bboxes: list. Per-word boxes in original image coordinates.
            file_name: image file name of the document.
            page_size: ``[width, height]`` of the document image.
        """
        self.guid = guid
        self.words = words
        self.labels = labels
        self.boxes = boxes
        self.actual_bboxes = actual_bboxes
        self.file_name = file_name
        self.page_size = page_size


class InputFeatures(object):
    """A single set of token-level, fixed-length features for one document."""

    def __init__(
        self,
        input_ids,  # token ids
        input_mask,  # 1 for real and special tokens, 0 for padding (default setting)
        segment_ids,  # segment id of every token ([CLS], sequence + [SEP], padding)
        label_ids,  # label id per token; continuations, specials and padding use pad_token_label_id
        boxes,  # scaled box of the word each token came from
        actual_bboxes,  # original-coordinate box of the word each token came from
        file_name,  # image file name
        page_size,  # [width, height] of the document image
    ):
        # Check every coordinate individually: ``all(boxes)`` only tests that
        # each box is non-empty, so comparing it with 0..1000 never fails.
        assert all(
            0 <= coord <= 1000 for box in boxes for coord in box
        ), "Error with input bbox ({}): the coordinate value is not between 0 and 1000".format(
            boxes
        )
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_ids = label_ids
        self.boxes = boxes
        self.actual_bboxes = actual_bboxes
        self.file_name = file_name
        self.page_size = page_size


def read_examples_from_file(data_dir, mode):
    """Read the three parallel FUNSD text files for ``mode`` into InputExamples.

    Args:
        data_dir: directory holding ``{mode}.txt``, ``{mode}_box.txt`` and
            ``{mode}_image.txt``.
        mode: split name, e.g. ``"train"`` or ``"test"``.

    Line formats (tab separated, read in lockstep):
        ``{mode}.txt``        word [TAB label]  (label may be absent)
        ``{mode}_box.txt``    word TAB "x0 y0 x1 y1" scaled to 0-1000
        ``{mode}_image.txt``  word TAB original box TAB "width height" TAB file name
    A blank line or a ``-DOCSTART-`` line ends the current document.

    Returns:
        list of InputExample, one per document, with guids ``"{mode}-1"``, ...
    """
    file_path = os.path.join(data_dir, "{}.txt".format(mode))
    box_file_path = os.path.join(data_dir, "{}_box.txt".format(mode))
    image_file_path = os.path.join(data_dir, "{}_image.txt".format(mode))
    guid_index = 1  # 1-based index of the document being read
    examples = []
    with open(file_path, encoding="utf-8") as f, open(
        box_file_path, encoding="utf-8"
    ) as fb, open(image_file_path, encoding="utf-8") as fi:
        words = []
        boxes = []
        actual_bboxes = []
        file_name = None
        page_size = None
        labels = []
        for line, bline, iline in zip(f, fb, fi):
            if line.startswith("-DOCSTART-") or line in ("", "\n"):
                # Document boundary: emit the collected document, then reset.
                if words:
                    examples.append(
                        InputExample(
                            guid="{}-{}".format(mode, guid_index),
                            words=words,
                            labels=labels,
                            boxes=boxes,
                            actual_bboxes=actual_bboxes,
                            file_name=file_name,
                            page_size=page_size,
                        )
                    )
                    guid_index += 1
                    words = []
                    boxes = []
                    actual_bboxes = []
                    file_name = None
                    page_size = None
                    labels = []
                continue

            splits = line.rstrip("\n").split("\t")
            bsplits = bline.rstrip("\n").split("\t")
            isplits = iline.rstrip("\n").split("\t")
            assert len(splits) in (1, 2)  # word [, label]
            assert len(bsplits) == 2  # word, scaled box
            assert len(isplits) == 4  # word, original box, image size, file name
            assert splits[0] == bsplits[0]
            words.append(splits[0])
            # Examples could have no label for mode = "test"; boxes are still
            # read so that words, labels and boxes stay aligned.
            labels.append(splits[1] if len(splits) == 2 else "O")
            boxes.append([int(b) for b in bsplits[1].split()])
            actual_bboxes.append([int(b) for b in isplits[1].split()])
            page_size = [int(i) for i in isplits[2].split()]
            file_name = isplits[3].strip()

        # The last document has no trailing separator line; emit it here.
        if words:
            examples.append(
                InputExample(
                    guid="{}-{}".format(mode, guid_index),
                    words=words,
                    labels=labels,
                    boxes=boxes,
                    actual_bboxes=actual_bboxes,
                    file_name=file_name,
                    page_size=page_size,
                )
            )
    return examples


def convert_examples_to_features(
    examples,
    label_list,
    max_seq_length,
    tokenizer,
    cls_token_at_end=False,
    cls_token="[CLS]",
    cls_token_segment_id=1,  # segment id of the CLS token
    sep_token="[SEP]",
    sep_token_extra=False,
    pad_on_left=False,
    pad_token=0,
    cls_token_box=None,  # defaults to [0, 0, 0, 0]
    sep_token_box=None,  # defaults to [1000, 1000, 1000, 1000]
    pad_token_box=None,  # defaults to [0, 0, 0, 0]
    pad_token_segment_id=0,  # segment id of padding
    pad_token_label_id=-1,  # label id for continuations, special tokens and padding
    sequence_a_segment_id=0,  # segment id of sequence tokens and SEP
    mask_padding_with_zero=True,  # input_mask: padding 0, every other token 1
):
    """Convert word-level InputExamples into fixed-length, token-level InputFeatures.

    Each word is tokenized; all its sub-tokens inherit the word's boxes, and
    only the first sub-token carries the word's label id (the rest get
    ``pad_token_label_id``). Sequences are truncated so the special tokens
    fit, then padded to ``max_seq_length``.

    `cls_token_at_end` define the location of the CLS token:
        - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
        - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
    `cls_token_segment_id` define the segment id associated to the CLS token
    (0 for BERT, 2 for XLNet)

    Returns:
        list of InputFeatures, one per example.

    Raises:
        KeyError: if an example label is not in ``label_list``.
    """
    # Resolve box defaults here so no mutable default list is shared across calls.
    if cls_token_box is None:
        cls_token_box = [0, 0, 0, 0]
    if sep_token_box is None:
        sep_token_box = [1000, 1000, 1000, 1000]
    if pad_token_box is None:
        pad_token_box = [0, 0, 0, 0]

    label_map = {label: i for i, label in enumerate(label_list)}
    # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
    special_tokens_count = 3 if sep_token_extra else 2
    max_word_tokens = max_seq_length - special_tokens_count

    features = []
    for (ex_index, example) in enumerate(examples):
        file_name = example.file_name
        page_size = example.page_size
        width, height = page_size
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))

        # Expand word-level information to token level.
        tokens = []
        token_boxes = []
        actual_bboxes = []
        label_ids = []
        for word, label, box, actual_bbox in zip(
            example.words, example.labels, example.boxes, example.actual_bboxes
        ):
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                # A word may tokenize to nothing; emitting a label for it would
                # shift every following label relative to its token.
                continue
            tokens.extend(word_tokens)
            token_boxes.extend([box] * len(word_tokens))
            actual_bboxes.extend([actual_bbox] * len(word_tokens))
            # Use the real label id for the first token of the word, and padding
            # ids for the remaining tokens.
            label_ids.extend(
                [label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1)
            )

        # Truncate so that the special tokens still fit into max_seq_length.
        if len(tokens) > max_word_tokens:
            tokens = tokens[:max_word_tokens]
            token_boxes = token_boxes[:max_word_tokens]
            actual_bboxes = actual_bboxes[:max_word_tokens]
            label_ids = label_ids[:max_word_tokens]

        # Append [SEP] (twice for RoBERTa). Special tokens get the full page as
        # their original-coordinate box.
        tokens += [sep_token]
        token_boxes += [sep_token_box]
        actual_bboxes += [[0, 0, width, height]]
        label_ids += [pad_token_label_id]
        if sep_token_extra:
            # roberta uses an extra separator b/w pairs of sentences
            tokens += [sep_token]
            token_boxes += [sep_token_box]
            actual_bboxes += [[0, 0, width, height]]
            label_ids += [pad_token_label_id]
        segment_ids = [sequence_a_segment_id] * len(tokens)

        # Place [CLS] at the end (XLNet) or at the front (BERT-style).
        if cls_token_at_end:
            tokens += [cls_token]
            token_boxes += [cls_token_box]
            actual_bboxes += [[0, 0, width, height]]
            label_ids += [pad_token_label_id]
            segment_ids += [cls_token_segment_id]
        else:
            tokens = [cls_token] + tokens
            token_boxes = [cls_token_box] + token_boxes
            actual_bboxes = [[0, 0, width, height]] + actual_bboxes
            label_ids = [pad_token_label_id] + label_ids
            segment_ids = [cls_token_segment_id] + segment_ids

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length. actual_bboxes is left unpadded.
        padding_length = max_seq_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            input_mask = (
                [0 if mask_padding_with_zero else 1] * padding_length
            ) + input_mask
            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
            label_ids = ([pad_token_label_id] * padding_length) + label_ids
            token_boxes = ([pad_token_box] * padding_length) + token_boxes
        else:
            input_ids += [pad_token] * padding_length
            input_mask += [0 if mask_padding_with_zero else 1] * padding_length
            segment_ids += [pad_token_segment_id] * padding_length
            label_ids += [pad_token_label_id] * padding_length
            token_boxes += [pad_token_box] * padding_length

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length
        assert len(token_boxes) == max_seq_length

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
            logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))
            logger.info("boxes: %s", " ".join([str(x) for x in token_boxes]))
            logger.info("actual_bboxes: %s", " ".join([str(x) for x in actual_bboxes]))

        features.append(
            InputFeatures(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                label_ids=label_ids,
                boxes=token_boxes,
                actual_bboxes=actual_bboxes,
                file_name=file_name,
                page_size=page_size,
            )
        )
    return features