Loading

关于LayoutLM中FunsdDataset

在使用 Hugging Face 编写 LayoutLM 在 FUNSD 数据集上的微调任务时,用到了 layoutlm.data.funsd 中的 FunsdDataset 和 InputFeatures 两个类。

微软开源代码地址 unilm/funsd.py at master · microsoft/unilm (github.com)

Hugging Face 示例所引用的 LayoutLM 代码的开源地址 unilm/funsd.py at master · NielsRogge/unilm (github.com)

以下是注释理解后的代码

import logging
import os

import torch
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)

#  主类
class FunsdDataset(Dataset):
    """Torch dataset exposing FUNSD features as per-field long tensors.

    Features are loaded from a cache file inside ``args.data_dir`` when one
    exists (and ``args.overwrite_cache`` is false); otherwise they are built
    with ``read_examples_from_file`` + ``convert_examples_to_features`` and,
    on the main process, written to that cache file.

    Each item is the tuple
    ``(input_ids, input_mask, segment_ids, label_ids, bboxes)``.
    """

    def __init__(self, args, tokenizer, labels, pad_token_label_id, mode):
        # Distributed training: every process except the first one waits
        # here so that only the first process builds the features; the
        # others read the cache afterwards.
        if args.local_rank not in [-1, 0] and mode == "train":
            torch.distributed.barrier()

        # The cache file name is derived from the mode, the last non-empty
        # path component of the model name, and the maximum sequence length.
        model_name_parts = [
            part for part in args.model_name_or_path.split("/") if part
        ]
        cache_name = "cached_{}_{}_{}".format(
            mode, model_name_parts[-1], str(args.max_seq_length)
        )
        cached_features_file = os.path.join(args.data_dir, cache_name)

        cache_usable = (
            os.path.exists(cached_features_file) and not args.overwrite_cache
        )
        if not cache_usable:
            logger.info("Creating features from dataset file at %s", args.data_dir)

            # Main path: read the word-level examples from the dataset
            # files, then convert them into token-level features.
            examples = read_examples_from_file(args.data_dir, mode)

            xlnet_like = bool(args.model_type in ["xlnet"])
            pad_token_id = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
            features = convert_examples_to_features(
                examples,  # word-level examples, one per document
                labels,  # full label set
                args.max_seq_length,  # maximum sequence length
                tokenizer,  # tokenizer
                # xlnet has a cls token at the end
                cls_token_at_end=xlnet_like,
                cls_token=tokenizer.cls_token,
                cls_token_segment_id=2 if xlnet_like else 0,
                sep_token=tokenizer.sep_token,
                # roberta uses an extra separator b/w pairs of sentences, cf.
                # github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
                sep_token_extra=bool(args.model_type in ["roberta"]),
                # pad on the left for xlnet
                pad_on_left=xlnet_like,
                pad_token=pad_token_id,
                pad_token_segment_id=4 if xlnet_like else 0,
                pad_token_label_id=pad_token_label_id,
            )

            # Only the main process (local_rank -1 or 0) writes the cache.
            if args.local_rank in [-1, 0]:
                logger.info("Saving features into cached file %s", cached_features_file)
                torch.save(features, cached_features_file)
        else:
            logger.info("Loading features from cached file %s", cached_features_file)
            features = torch.load(cached_features_file)

        # The first process releases the waiting processes once the cache
        # is available.
        if args.local_rank == 0 and mode == "train":
            torch.distributed.barrier()

        self.features = features

        def _long_tensor(field):
            # Stack one attribute of every feature into a single long tensor.
            return torch.tensor(
                [getattr(feature, field) for feature in features],
                dtype=torch.long,
            )

        # Convert to tensors and build the dataset.
        self.all_input_ids = _long_tensor("input_ids")
        self.all_input_mask = _long_tensor("input_mask")
        self.all_segment_ids = _long_tensor("segment_ids")
        self.all_label_ids = _long_tensor("label_ids")
        self.all_bboxes = _long_tensor("boxes")

    def __len__(self):
        """Return the number of features held by the dataset."""
        return len(self.features)

    def __getitem__(self, index):
        """Return the tensors of the feature at ``index`` as a 5-tuple."""
        per_field = (
            self.all_input_ids,
            self.all_input_mask,
            self.all_segment_ids,
            self.all_label_ids,
            self.all_bboxes,
        )
        return tuple(field[index] for field in per_field)


class InputExample(object):
    """Word-level content of one document image, used for token classification."""

    def __init__(self, guid, words, labels, boxes, actual_bboxes, file_name, page_size):
        """Store the word-level information of a single document.

        Args:
            guid: Unique id for the example.
            words: list. The words of the sequence.
            labels: list. The label of each word in ``words``.
            boxes: list. The scaled bounding box of each word.
            actual_bboxes: list. The bounding box of each word at the
                original image size.
            file_name: Name of the image file the words come from.
            page_size: Size of the document image.
        """
        # Identification of the source document.
        self.guid, self.file_name, self.page_size = guid, file_name, page_size
        # Parallel per-word lists.
        self.words, self.labels = words, labels
        self.boxes, self.actual_bboxes = boxes, actual_bboxes


class InputFeatures(object):
    """Token-level features of one document, ready to be turned into tensors."""

    def __init__(
        self,
        input_ids,      # vocabulary index of each token
        input_mask,     # distinguishes real tokens from padding
        segment_ids,    # segment id of each token
        label_ids,      # label id of each token
        boxes,          # scaled bounding box of the word each token belongs to
        actual_bboxes,  # original-size bounding box of the word each token belongs to
        file_name,      # name of the image file
        page_size,      # size of the document image
    ):
        """Validate the scaled boxes and store all fields.

        Args:
            input_ids: Token indices.
            input_mask: Mask separating real tokens from padding tokens.
            segment_ids: Segment id per token.
            label_ids: Label id per token.
            boxes: Iterable of boxes, each an iterable of coordinates that
                must all lie in the closed range [0, 1000].
            actual_bboxes: Boxes at the original image size.
            file_name: Name of the image file.
            page_size: Size of the document image.

        Raises:
            AssertionError: If any coordinate in ``boxes`` is outside
                [0, 1000].
        """
        # The previous check, ``0 <= all(boxes) <= 1000``, compared the
        # boolean returned by ``all`` with the bounds and therefore could
        # never fail. Every coordinate of every box is checked instead.
        assert all(
            0 <= coordinate <= 1000 for box in boxes for coordinate in box
        ), "Error with input bbox ({}): the coordinate value is not between 0 and 1000".format(
            boxes
        )
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_ids = label_ids
        self.boxes = boxes
        self.actual_bboxes = actual_bboxes
        self.file_name = file_name
        self.page_size = page_size


def read_examples_from_file(data_dir, mode):
    """Read the word-level examples of one dataset split.

    Three files in ``data_dir`` are read in lockstep, line by line:

    * ``{mode}.txt``:       ``word<TAB>label``
    * ``{mode}_box.txt``:   ``word<TAB>scaled box``
    * ``{mode}_image.txt``: ``word<TAB>actual box<TAB>page size<TAB>file name``

    A line of ``{mode}.txt`` that starts with ``-DOCSTART-`` or is blank
    closes the current document; the words collected so far become one
    ``InputExample``.

    Args:
        data_dir: Directory containing the three files.
        mode: Split name used as the file prefix, e.g. ``"train"`` or
            ``"test"``.

    Returns:
        A list of ``InputExample``, one per document, with guids of the
        form ``"{mode}-{index}"`` (index starting at 1).
    """
    # Full paths of the three parallel files.
    file_path = os.path.join(data_dir, "{}.txt".format(mode))
    box_file_path = os.path.join(data_dir, "{}_box.txt".format(mode))
    image_file_path = os.path.join(data_dir, "{}_image.txt".format(mode))

    guid_index = 1  # running index of the document being read
    examples = []
    with open(file_path, encoding="utf-8") as f, open(
        box_file_path, encoding="utf-8"
    ) as fb, open(image_file_path, encoding="utf-8") as fi:
        # State of the document currently being collected.
        words = []
        boxes = []
        actual_bboxes = []
        file_name = None
        page_size = None
        labels = []

        # Read the three files simultaneously, one line from each.
        for line, bline, iline in zip(f, fb, fi):
            # A start marker or a blank line ends the current document.
            if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                if words:
                    examples.append(
                        InputExample(
                            guid="{}-{}".format(mode, guid_index),
                            words=words,
                            labels=labels,
                            boxes=boxes,
                            actual_bboxes=actual_bboxes,
                            file_name=file_name,
                            page_size=page_size,
                        )
                    )
                    guid_index += 1
                    words = []
                    boxes = []
                    actual_bboxes = []
                    file_name = None
                    page_size = None
                    labels = []
            else:
                # Regular data line: split and check the expected layout.
                splits = line.split("\t")
                bsplits = bline.split("\t")
                isplits = iline.split("\t")
                assert len(splits) == 2     # word, label
                assert len(bsplits) == 2    # word, scaled box
                assert len(isplits) == 4    # word, actual box, page size, file name
                assert splits[0] == bsplits[0]

                words.append(splits[0])
                if len(splits) > 1:
                    labels.append(splits[-1].replace("\n", ""))
                    box = bsplits[-1].replace("\n", "")
                    box = [int(b) for b in box.split()]
                    boxes.append(box)
                    actual_bbox = [int(b) for b in isplits[1].split()]
                    actual_bboxes.append(actual_bbox)
                    page_size = [int(i) for i in isplits[2].split()]
                    file_name = isplits[3].strip()
                else:
                    # Examples could have no label for mode = "test".
                    # NOTE: unreachable while the assertions above are active.
                    labels.append("O")

        # The last document is not followed by a separator line, so it has
        # to be stored explicitly once the files are exhausted.
        if words:
            examples.append(
                InputExample(
                    # Fixed: this used "%s-%d".format(...), which contains no
                    # "{}" placeholders and yielded the literal guid "%s-%d".
                    guid="{}-{}".format(mode, guid_index),
                    words=words,
                    labels=labels,
                    boxes=boxes,
                    actual_bboxes=actual_bboxes,
                    file_name=file_name,
                    page_size=page_size,
                )
            )
    return examples


def convert_examples_to_features(
    examples,
    label_list,
    max_seq_length,
    tokenizer,
    cls_token_at_end=False,
    cls_token="[CLS]",
    cls_token_segment_id=1,     # segment id given to the CLS token
    sep_token="[SEP]",
    sep_token_extra=False,
    pad_on_left=False,
    pad_token=0,
    cls_token_box=[0, 0, 0, 0],
    sep_token_box=[1000, 1000, 1000, 1000],
    pad_token_box=[0, 0, 0, 0],
    pad_token_segment_id=0,     # segment id given to padding positions
    pad_token_label_id=-1,      # label id for continuation sub-tokens and CLS/SEP/PAD
    sequence_a_segment_id=0,    # segment id given to the word tokens and SEP
    mask_padding_with_zero=True,    # padding gets mask 0, every other token gets 1
):
    """Convert word-level examples into fixed-length token-level features.

    Every word is split into sub-tokens by ``tokenizer``; each sub-token
    inherits the word's scaled box and actual box. Only the first sub-token
    of a word carries the word's label id, the remaining ones get
    ``pad_token_label_id``. The sequence is truncated to leave room for the
    special tokens, SEP (one, or two with ``sep_token_extra``) and CLS are
    added, and everything except ``actual_bboxes`` is padded to
    ``max_seq_length``.

    `cls_token_at_end` define the location of the CLS token:
        - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
        - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
    `cls_token_segment_id` define the segment id associated to the CLS token
    (0 for BERT, 2 for XLNet)

    Args:
        examples: Iterable of ``InputExample``.
        label_list: All labels; a label's position is its id.
        max_seq_length: Length of every produced sequence.
        tokenizer: Object providing ``tokenize(word)`` and
            ``convert_tokens_to_ids(tokens)``.
        pad_on_left: Pad at the start instead of the end.
        cls_token_box, sep_token_box, pad_token_box: Boxes assigned to the
            special tokens. They are only read, never mutated.

    Returns:
        A list of ``InputFeatures``, one per example.
    """

    label_map = {label: i for i, label in enumerate(label_list)}

    # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
    special_tokens_count = 3 if sep_token_extra else 2
    max_word_tokens = max_seq_length - special_tokens_count

    features = []
    for (ex_index, example) in enumerate(examples):
        file_name = example.file_name
        page_size = example.page_size
        width, height = page_size
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))

        # Expand the word-level information into token-level information.
        tokens = []
        token_boxes = []
        actual_bboxes = []
        label_ids = []
        for word, label, box, actual_bbox in zip(
            example.words, example.labels, example.boxes, example.actual_bboxes
        ):
            word_tokens = tokenizer.tokenize(word)
            # A word that tokenizes to nothing contributes no token and no
            # box, so it must not contribute a label either; otherwise
            # label_ids drifts out of alignment with tokens.
            if not word_tokens:
                continue
            tokens.extend(word_tokens)
            token_boxes.extend([box] * len(word_tokens))
            actual_bboxes.extend([actual_bbox] * len(word_tokens))
            # Use the real label id for the first token of the word, and padding ids for the remaining tokens
            label_ids.extend(
                [label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1)
            )

        # Truncate documents whose tokens plus special tokens would exceed
        # the maximum sequence length.
        if len(tokens) > max_word_tokens:
            tokens = tokens[:max_word_tokens]
            token_boxes = token_boxes[:max_word_tokens]
            actual_bboxes = actual_bboxes[:max_word_tokens]
            label_ids = label_ids[:max_word_tokens]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids:   0   0   0   0  0     0   0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.

        # Append SEP (once, or twice for RoBERTa); special tokens get the
        # whole page as their actual box.
        tokens += [sep_token]
        token_boxes += [sep_token_box]
        actual_bboxes += [[0, 0, width, height]]
        label_ids += [pad_token_label_id]
        if sep_token_extra:
            # roberta uses an extra separator b/w pairs of sentences
            tokens += [sep_token]
            token_boxes += [sep_token_box]
            actual_bboxes += [[0, 0, width, height]]
            label_ids += [pad_token_label_id]
        segment_ids = [sequence_a_segment_id] * len(tokens)

        # Place CLS at the end or at the start of the sequence.
        if cls_token_at_end:
            tokens += [cls_token]
            token_boxes += [cls_token_box]
            actual_bboxes += [[0, 0, width, height]]
            label_ids += [pad_token_label_id]
            segment_ids += [cls_token_segment_id]
        else:
            tokens = [cls_token] + tokens
            token_boxes = [cls_token_box] + token_boxes
            actual_bboxes = [[0, 0, width, height]] + actual_bboxes
            label_ids = [pad_token_label_id] + label_ids
            segment_ids = [cls_token_segment_id] + segment_ids

        # Map the token sequence to vocabulary indices.
        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        # NOTE: actual_bboxes is intentionally left unpadded, as before.
        padding_length = max_seq_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            input_mask = (
                [0 if mask_padding_with_zero else 1] * padding_length
            ) + input_mask
            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
            label_ids = ([pad_token_label_id] * padding_length) + label_ids
            token_boxes = ([pad_token_box] * padding_length) + token_boxes
        else:
            input_ids += [pad_token] * padding_length
            input_mask += [0 if mask_padding_with_zero else 1] * padding_length
            segment_ids += [pad_token_segment_id] * padding_length
            label_ids += [pad_token_label_id] * padding_length
            token_boxes += [pad_token_box] * padding_length

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length
        assert len(token_boxes) == max_seq_length

        # Log the first few converted examples for inspection.
        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
            logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))
            logger.info("boxes: %s", " ".join([str(x) for x in token_boxes]))
            logger.info("actual_bboxes: %s", " ".join([str(x) for x in actual_bboxes]))

        features.append(
            InputFeatures(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                label_ids=label_ids,
                boxes=token_boxes,
                actual_bboxes=actual_bboxes,
                file_name=file_name,
                page_size=page_size,
            )
        )
    return features

 

posted @ 2022-07-07 16:13  丘野  阅读(169)  评论(0编辑  收藏  举报