RFS(Repeat Factor Sampling)的策略如下:对每个类别c,先计算统计量\(f_c\),即至少包含一个类别c实例的图片数占全部图片数的比例。然后按公式\(r_c = \max(1, \sqrt{t / f_c})\)计算类别级的重复因子,其中t是一个超参数,一般取t = 0.001。

计算出\(r_c\)之后,对每一张图片i,取其重复因子为\(r_i = \max_{c \in i} r_c\),即图片i中出现的所有类别的\(r_c\)的最大值。这样,在每一个epoch里,SGD的data sampler会生成一个随机排列(random permutation),其中每一张图片i出现\(r_i\)次;当\(r_i\)不是整数时,采用随机取整,使其在期望意义上出现\(r_i\)次。

RFS并没有严格的理论依据,它只是一个启发式算法(heuristic)。其直观效果是:当某个类别的\(f_c\)缩小为原来的\(\lambda\)倍(\(\lambda < 1\))时,该类别的重复因子\(r_c\)会增大为原来的\(\sqrt{1 / \lambda}\)倍(前提是缩小前后\(r_c\)都没有被\(\max(1, \cdot)\)截断为1)。


import itertools
import math
from collections import defaultdict
from typing import Optional
import torch
from torch.utils.data.sampler import Sampler

from detectron2.utils import comm

class TrainingSampler(Sampler):
    """Infinite stream of dataset indices for training.

    During training we only care about an "infinite stream" of training data,
    so this sampler yields an endless sequence of indices. The stream is
    ``shuffle(range(size)) + shuffle(range(size)) + ...`` when ``shuffle`` is
    True, and ``range(size) + range(size) + ...`` when ``shuffle`` is False.

    Every worker builds the same stream (they share ``seed``) and keeps only
    every ``world_size``-th element starting at its own rank, so the workers
    consume disjoint, interleaved positions of that shared stream.
    """

    def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from.
            shuffle (bool): whether to shuffle the indices or not.
            seed (int): the initial seed of the shuffle. Must be the same across all
                workers. If None, will use a random seed shared among workers
                (require synchronization among all workers).
        """
        assert size > 0
        self._size = size
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def __iter__(self):
        """Yield this worker's share of the infinite index stream (a generator)."""
        start = self._rank
        # itertools.islice(iterable, start, stop, step): take every world_size-th
        # element starting at this worker's rank. The indices a worker receives
        # are therefore interleaved with other workers', not contiguous.
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        """Endlessly yield indices over ``range(size)``, one full pass per epoch."""
        # Seed explicitly: otherwise the user-supplied seed has no effect. Workers
        # share the seed, so they all generate the identical stream, which is what
        # makes the islice-based split in __iter__ consistent.
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g).tolist()
            else:
                yield from torch.arange(self._size).tolist()

 class RepeatFactorTrainingSampler(Sampler):

    def __init__(self, repeat_factors, *, shuffle=True, seed=None):
            repeat_factors (Tensor): a float vector, the repeat factor for each indice. When it's
                full of ones, it is equivalent to ``TrainingSampler(len(repeat_factors), ...)``.
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        # Split into whole number (_int_part) and fractional (_frac_part) parts. 分成整数部分与小数部分。
        self._int_part = torch.trunc(repeat_factors)
        self._frac_part = repeat_factors - self._int_part

    def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh):
            dataset_dicts (list[dict]): d2的注释格式
            repeat_thresh (float): 在threshhold之下的类别是应该重复的。

        # 对每一个类来讲,要算出它的frequency
        category_freq = defaultdict(int)
        for dataset_dict in dataset_dicts:  #对每一张图片
            cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        num_images = len(dataset_dicts)
        for k, v in category_freq.items():
            category_freq[k] = v / num_images

        # 2. For each category c, compute the category-level repeat factor:
        #    r(c) = max(1, sqrt(t / f(c)))
        category_rep = {
            cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq))
            for cat_id, cat_freq in category_freq.items()

        # 3. 返回的一个tensor,里面每一个值对应的是每一个index对应的repeat次数。
        #    r(I) = max_{c in I} r(c)
        rep_factors = []
        for dataset_dict in dataset_dicts:
            cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
            rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0)

        return torch.tensor(rep_factors, dtype=torch.float32)
	def _get_epoch_indices(self, generator):
        Create a list of dataset indices (with repeats) to use for one epoch.

            generator (torch.Generator): 随机数生成器,用来概率取整。

        # 这里说明一件事情就是,由于reapeat factor是一个小数,所以我们
        # 只能够采用一个带有概率的上取整或者下取整。这样能在期望意义上,
        # 等价的。
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()
        # Construct a list of indices in which we repeat images as specified
        indices = []
        for dataset_index, rep_factor in enumerate(rep_factors):
            indices.extend([dataset_index] * int(rep_factor.item()))
        return torch.tensor(indices, dtype=torch.int64)

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        g = torch.Generator()
        while True:
            # Sample indices with repeats determined by stochastic rounding; each
            # "epoch" may have a slightly different size due to the rounding.
            indices = self._get_epoch_indices(g)
            if self._shuffle:
                randperm = torch.randperm(len(indices), generator=g)
                yield from indices[randperm].tolist()
                yield from indices.tolist()
class InferenceSampler(Sampler):
    """
    Produce indices for inference across all workers.
    Inference needs to run on the __exact__ set of samples,
    therefore when the total number of samples is not divisible by the number of workers,
    this sampler produces different number of samples on different workers.
    """

    def __init__(self, size: int):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
        """
        self._size = size
        assert size > 0
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        """Return the contiguous ``range`` of dataset indices assigned to ``rank``.

        Each worker gets a shard of ``ceil(total_size / world_size)`` consecutive
        indices; the trailing shard(s) may be shorter or empty, so every index is
        covered exactly once across all ranks.
        """
        # (n - 1) // k + 1 == ceil(n / k) for n >= 1.
        shard_size = (total_size - 1) // world_size + 1
        begin = shard_size * rank
        end = min(shard_size * (rank + 1), total_size)
        # When begin >= end (more workers than shards) this range is empty.
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
