Using the Albumentations Library

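1 Introduction to Albumentations

A handy open-source image processing library. It transforms RGB, grayscale and multispectral images together with their masks, bounding boxes and keypoints in a single call. It is typically used for data augmentation and is part of the PyTorch ecosystem.

Homepage: https://albumentations.ai/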

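2 Key points

Supported transforms: https://albumentations.ai/docs/getting_started/transforms_and_targets/

  • Transforms fall into two categories: pixel-level transforms and spatial-level transforms. Pixel-level transforms only change pixel values, so masks and other targets are left untouched. Spatial-level transforms change the geometry, so the mask is transformed in sync with the image;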

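Transform probabilities: https://albumentations.ai/docs/getting_started/setting_probabilities/

  • Every transform has a parameter p that controls the probability of applying it;
  • the default p differs between transforms: some default to 1 (for example Resize and Normalize), while many augmentations default to 0.5;
  • with nesting, for example a GaussNoise inside a OneOf inside a Compose, the probabilities along the path multiply;
  • inside a OneOf, the p values of the children are normalized and act as relative weights for picking one child;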
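
The probability rules above can be written out as plain arithmetic. This is a sketch of the rules as described in the docs, not Albumentations code; the helpers `nested_prob` and `oneof_member_prob` are hypothetical, it assumes the child picked by a OneOf is then always applied, and the numbers in the usage lines are made up for illustration.

```python
from fractions import Fraction


def nested_prob(*ps):
    """Hypothetical helper: chance that a transform nested under several
    containers fires, i.e. the product of every p along the path."""
    result = Fraction(1)
    for p in ps:
        result *= Fraction(p)  # pass decimals as strings, e.g. "0.9", to stay exact
    return result


def oneof_member_prob(outer_ps, weights, index):
    """Hypothetical helper: chance that child `index` of a OneOf fires.

    outer_ps: the p values on the path down to and including the OneOf
    weights:  the p values of the OneOf children, normalized to relative weights
    """
    weights = [Fraction(w) for w in weights]
    share = weights[index] / sum(weights)
    return nested_prob(*outer_ps) * share


# Compose(p=1) -> OneOf(p=0.5) -> [GaussNoise(p=0.9), ISONoise(p=0.1)]
print(oneof_member_prob(["1", "0.5"], ["0.9", "0.1"], 0))  # 9/20
# Compose(p=0.9) -> HorizontalFlip(p=0.5)
print(nested_prob("0.9", "0.5"))  # 9/20
```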

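Capturing transform parameters: https://albumentations.ai/docs/examples/replay/

  • ReplayCompose records the parameters sampled during a call, so exactly the same transform can be replayed on other images;
  • the recorded parameters are also useful for debugging;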

Transforming several images in sync, for example several images sharing one mask, or one image with several masks: https://albumentations.ai/docs/examples/example_multi_target/

3 Code examples

3.1 Test code

Useful for checking locally what each transform actually does.

    import random
    import cv2
    from matplotlib import pyplot as plt
    import albumentations as A
    
    ## Visualize the image and mask before and after the transform
    def visualize(image, mask, original_image=None, original_mask=None):
        fontsize = 14
    
        if original_image is None and original_mask is None:
            f, ax = plt.subplots(2, 1, figsize=(8, 5))
    
            ax[0].imshow(image, cmap=plt.cm.gray)
            ax[1].imshow(mask, cmap=plt.cm.gray)
        else:
            f, ax = plt.subplots(2, 2, figsize=(8, 5))
    
            ax[0, 0].imshow(original_image, cmap=plt.cm.gray)
            ax[0, 0].set_title('Original image', fontsize=fontsize)
    
            ax[1, 0].imshow(original_mask, cmap=plt.cm.gray)
            ax[1, 0].set_title('Original mask', fontsize=fontsize)
    
            ax[0, 1].imshow(image, cmap=plt.cm.gray)
            ax[0, 1].set_title('Transformed image', fontsize=fontsize)
    
            ax[1, 1].imshow(mask, cmap=plt.cm.gray)
            ax[1, 1].set_title('Transformed mask', fontsize=fontsize)
    
        # Show the figure (needed when running as a plain script)
        plt.tight_layout()
        plt.show()
    
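    image = cv2.imread(r"D:\samples\0000.png", cv2.IMREAD_GRAYSCALE)
    mask = cv2.imread(r"D:\samples\0000_mask.png", cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising when a file cannot be read
    assert image is not None and mask is not None, "check the image/mask paths"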
    
    ## Example 1: pixel-level transforms, the mask is left unchanged
    # aug = A.RandomGamma(p=1, gamma_limit=(60, 90))
    # aug = A.RandomBrightnessContrast(p=1, brightness_limit=(-0.1, 0.2), contrast_limit=(-0.4, 0.4))
    # aug = A.CLAHE(p=1, clip_limit=2.0, tile_grid_size=(4, 4))
    # aug = A.MotionBlur(p=1, blur_limit=5)
    # aug = A.GlassBlur(p=1, sigma=0.05, max_delta=1, iterations=1)
    # aug = A.GaussianBlur(p=1, blur_limit=(1, 3))
    
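    ## Example 2: spatial-level transforms, the mask is transformed together with the image
    # alpha_affine is deprecated in newer Albumentations releases; drop it if it raises an error
    # aug = A.ElasticTransform(p=1, alpha=80, sigma=8, alpha_affine=10)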
    # aug = A.GridDistortion(p=1, num_steps=5, distort_limit=(-0.3, 0.3))
    # aug = A.OpticalDistortion(distort_limit=1, shift_limit=1, p=1)
    # aug = A.RandomResizedCrop(size=(120,248), scale=(0.5,1.0), ratio=(1.8,2.4), p=1)
    # aug = A.Affine(p=1, scale=(0.9,1.1), translate_percent=None, shear=(-20, 20), rotate=(-40,40))
    aug = A.Perspective(p=1, scale=(0.05, 0.3))
    
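    # Fix the seed so results are reproducible; remove this in real use.
    # (This relies on Albumentations drawing from Python's global random state, as older
    # releases do; newer releases may ignore it and take a seed via A.Compose instead.)
    random.seed(9)
    augmented = aug(image=image, mask=mask)
    
    image_aug = augmented['image']
    mask_aug = augmented['mask']
    
    print(f"image_aug.shape {image_aug.shape}")
    print(f"mask_aug.shape {mask_aug.shape}")
    
    visualize(image_aug, mask_aug, original_image=image, original_mask=mask)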

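3.2 Code for actual use

For data augmentation inside a PyTorch Dataset / DataLoader.

  • HorizontalFlip is applied with probability 0.9 * 0.5 = 0.45;
  • RandomGamma in the first OneOf is applied with probability 0.9 * 0.6 * (1 / (1+2+1)) = 0.135;
  • RandomResizedCrop produces size=(120, 248) only when it is actually applied; if the source images are not already 120x248, samples come out in different sizes and the default DataLoader collation cannot stack them into a batch;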

    import albumentations as A
    
    # In the Dataset's __init__: define the transformations
    self.transform = A.Compose([
        A.OneOf([
            A.RandomGamma(p=1, gamma_limit=(60, 90)),
            # Inside OneOf, p only acts as a relative weight (1:2:1 here); if your
            # version rejects p > 1, 0.25/0.5/0.25 gives the same split
            A.RandomBrightnessContrast(p=2, brightness_limit=(-0.1, 0.2), contrast_limit=(-0.4, 0.4)),
            A.CLAHE(p=1, clip_limit=2.0, tile_grid_size=(4, 4))
        ], p=0.6),
        A.OneOf([
            A.MotionBlur(p=1, blur_limit=5),
            A.GlassBlur(p=1, sigma=0.05, max_delta=1, iterations=1),
            A.GaussianBlur(p=1, blur_limit=(1, 3))
        ], p=0.6),
        A.HorizontalFlip(p=0.5),
        A.Affine(p=0.6, scale=(0.9, 1.1), translate_percent=None, shear=(-10, 10), rotate=(-30, 30)),
        A.RandomResizedCrop(p=0.6, size=(120, 248), scale=(0.6, 1.0), ratio=(1.9, 2.3)),
        A.Perspective(p=0.8, scale=(0.05, 0.3)),
    ], p=0.9, additional_targets={'image0': 'image'})  # 'image0' gets the same transforms as 'image'
    
    # In __getitem__: augment only during training
    if self.train:
        transformed = self.transform(image=data['ir'],
                                     image0=data['speckle'],
                                     mask=data['gt'])
        data['ir'] = transformed['image']
        data['speckle'] = transformed['image0']
        data['gt'] = transformed['mask']
posted @ 2024-07-31 22:26 天地辽阔