凯鲁嘎吉
用书写铭记日常,最迷人的不在远方

Python小练习:数据增广Random Shift

作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/

DrQ-v2(Mastering visual continuous control: Improved data-augmented reinforcement learning)中提到一种新的数据增广方式Random Shift,下面以图像增广为例,来看看该增广算子的实现原理。

首先在Python源码的目录下新建一个img的文件夹,里面存放待增广的图片,jpg格式,然后运行image_aug.py,增广前后的对比图保存在和源码同一目录下。

1. image_aug.py

  1 # -*- coding: utf-8 -*-
  2 # Author:凯鲁嘎吉 Coral Gajic
  3 # https://www.cnblogs.com/kailugaji/
  4 # 批量对图像进行增广(增广算子为随机平移)
  5 '''
  6 增广算子来源于DrQ-v2:
  7     Yarats D, Fergus R, Lazaric A, et al.
  8     Mastering visual continuous control: Improved data-augmented reinforcement learning[J].
  9     arXiv preprint arXiv:2107.09645, 2021.
 10 '''
 11 
 12 import os
 13 import imageio
 14 from PIL import Image
 15 import numpy as np
 16 import matplotlib.pyplot as plt
 17 import torch
 18 from torch import nn
 19 import torch.nn.functional as F
 20 from pylab import *
 21 import torchvision.transforms as transforms
 22 
 23 # Random Shift
 24 # 尺寸为 A × B 的图像每条边填充 pad 个像素(通过重复边界像素),然后随机裁剪回原始 A × B 尺寸
 25 class RandomShiftsAug(nn.Module):
 26     '''
 27         https://github.com/facebookresearch/drqv2/blob/main/drqv2.py
 28     '''
 29     def __init__(self, pad=4,aug=True):
 30         # pad: padding 填充
 31         super().__init__()
 32         self.pad = pad
 33         self.aug = aug
 34 
 35     def forward(self, x):
 36         if self.aug:
 37             n, _, h, w = x.size()
 38             # n,c,h,w分别表示batch数,通道数,高,宽
 39             padding = tuple([self.pad] * 4)
 40             # 分别对前,后做多少位的padding操作
 41             # 例如tuple([4] * 4)=> (4, 4, 4, 4),上下左右四个方位都做pad的填充
 42             x = F.pad(x, padding, 'replicate')
 43             # 对高维tensor的形状补齐操作
 44             # replicate​​​:使用tensor自身边界值补齐指定的维度。对于数据​​012​​​,结果可以为​​0001222​
 45             eps = 1.0 / (w + 2 * self.pad)
 46             # 若w=84, pad=4, 则w + 2 * self.pad=92, eps=0.010869565217391304
 47             arange = torch.linspace(-1.0 + eps, # 区间左侧
 48                                     1.0 - eps, # 区间右侧
 49                                     w + 2 * self.pad, # 92个点
 50                                     device=x.device,
 51                                     dtype=x.dtype)[:w]
 52             # 线性间距向量
 53             # torch.linspace(start, end, steps=100, out=None) → Tensor
 54             # 返回一个1维张量,包含在区间start和end上均匀间隔的step个点
 55             # 输出张量的长度由steps决定
 56             # 例如:生成0到10的5个数构成的等差数列
 57             # b = torch.linspace(0,10,steps=5)
 58             # tensor([ 0.0000,  2.5000,  5.0000,  7.5000, 10.0000])
 59             eps_h = 1.0 / (h + 2 * self.pad) # 若h=84, pad=4, 则eps_h=0.010869565217391304
 60             arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
 61             # 扩充第一个维度
 62             # 重复h行
 63             # 扩充第三个维度
 64             arange_w = torch.linspace(-1.0 + eps_h,
 65                                     1.0 - eps_h,
 66                                     h + 2 * self.pad,
 67                                     device=x.device,
 68                                     dtype=x.dtype)[:h]
 69             arange_w = arange_w.unsqueeze(1).repeat(1, w).unsqueeze(2)
 70             # 扩充第二个维度
 71             # 重复w列
 72             # 扩充第三个维度
 73             # arange_w = arange_w.unsqueeze(0).repeat(w, 1).unsqueeze(2)
 74             base_grid = torch.cat([arange, arange_w], dim=2)
 75             # 将两个张量按第三个维度拼接在一起
 76             base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
 77             # n个数据,重复n次操作
 78 
 79             shift = torch.randint(0,
 80                                 2 * self.pad + 1, # pad=4, 值为9
 81                                 size=(n, 1, 1, 2),
 82                                 device=x.device,
 83                                 dtype=x.dtype)
 84             shift[:,:,:,0] *= 2.0 / (w + 2 * self.pad)
 85             shift[:,:,:,1] *= 2.0 / (h + 2 * self.pad)
 86             grid = base_grid + shift
 87             # 随机平移
 88             # 在 random shift 之后还使用了 bilinear interpolation
 89             return F.grid_sample(x, # x.shape: torch.Size([1, 3, 420, 720])
 90                                 grid, # grid.shape: torch.Size([1, 300, 600, 2])
 91                                 padding_mode='zeros',
 92                                 align_corners=False)
 93             # 输出: torch.Size([1, 3, 300, 600])
 94         # 应用双线性插值,把输入的tensor转换为指定大小
 95         # 参考:https://betheme.net/qianduan/43027.html?action=onClick
 96         # 给定维度为(N,C,Hin,Win) 的input,维度为(N,Hout,Wout,2) 的grid
 97         # 则该函数output的维度为(N,C,Hout,Wout)
 98         # padding_mode表示当grid中的坐标位置超出边界时像素值的填充方式
 99         # 如果为zeros,则表示一旦grid坐标超出边界,则用0去填充输出特征图的相应位置元素
100         else:
101             return x
102 
103 path = "./img" # 打开存放图像的文件夹
104 dirs = os.listdir(path) # ['1.jpg', '2.jpg', '3.jpg']
105 len_dir = len(dirs) # len_dir张图片
106 count=0
107 # 文件夹的每一幅图像都执行增广操作
108 for i in dirs:
109     image_pad = imageio.imread(os.path.join(path,i)) # i: 'xxx.jpg'
110     image_pad = Image.fromarray(image_pad).resize((600, 300)) # 重新调整图像尺寸
111     transf = transforms.ToTensor() # 将原始数据形式(图像)转换成tensor
112     outs = transf(image_pad) # tensor数据格式是torch(C,H,W)
113     # torch.Size([3, 300, 600]) 和resize里面的正好相反
114     outs = outs.unsqueeze(0)  # 扩充第一个维度 torch.Size([1, 3, 300, 600])
115     shift = RandomShiftsAug(pad=60)
116     # Random Shift
117     # 图像每个边填充pad,然后再随机裁剪成600*300
118     outs_1 = shift(outs)
119     # outs.shape: torch.Size([1, 3, 300, 600])
120     # outs_1.shape: torch.Size([1, 3, 300, 600])
121     outs_1 = outs_1.numpy()  # tensor转换成numpy
122     outs_1 = np.transpose(outs_1[0], [1, 2, 0]) # 取出第1幅图像(3, 300, 600)
123     # 交换维度,把第一个维度放在最后,变成(300, 600, 3)
124     # 展示图片
125     # 原图
126     temp1 = int(2*count+1) # 竖着排
127     plt.subplot(len_dir, 2, temp1) # len_dir行,2列,第temp1个图片
128     plt.imshow(image_pad)
129     plt.axis('off')
130     plt.rcParams['font.sans-serif'] = ['KaiTI']
131     plt.rcParams['axes.unicode_minus'] = False
132     if count == 0:
133         plt.title("增广前")
134     # 增广后
135     plt.subplot(len_dir, 2, temp1+1)
136     plt.imshow(outs_1)
137     if count == 0:
138         plt.title("增广后")
139     plt.axis('off')
140     plt.subplots_adjust(wspace=0.05, hspace=0.05)
141     count = count + 1
142     print('---正在处理第%d张图片---' % count)
143 
144 # plt.tight_layout()
145 plt.savefig('Aug_image.png', bbox_inches='tight', pad_inches=0.0, dpi=1500)
146 plt.show()
147 print('-*-^_^-*-图像处理完成-*-^_^-*-')

2. 结果

补充:DrQ(Image augmentation is all you need: Regularizing deep reinforcement learning from pixels)里面关于Random Shift的实现:

 1 import os
 2 import imageio.v2 as imageio
 3 from PIL import Image
 4 import numpy as np
 5 import matplotlib.pyplot as plt
 6 import torch
 7 from torch import nn
 8 import torch.nn.functional as F
 9 from pylab import *
10 import torchvision.transforms as transforms
11 import kornia
12 
13 pad = 60
14 path = "./img" # 打开存放图像的文件夹
15 dirs = os.listdir(path) # ['1.jpg', '2.jpg', '3.jpg']
16 len_dir = len(dirs) # len_dir张图片
17 count=0
18 # 文件夹的每一幅图像都执行增广操作
19 for i in dirs:
20     image_pad = imageio.imread(os.path.join(path,i)) # i: 'xxx.jpg'
21     image_pad = Image.fromarray(image_pad).resize((600, 300)) # 重新调整图像尺寸
22     transf = transforms.ToTensor() # 将原始数据形式(图像)转换成tensor
23     outs = transf(image_pad) # tensor数据格式是torch(C,H,W)
24     # torch.Size([3, 300, 600]) 和resize里面的正好相反
25     outs = outs.unsqueeze(0)  # 扩充第一个维度 torch.Size([1, 3, 300, 600])
26     shift = nn.Sequential(nn.ReplicationPad2d(pad),
27                           kornia.augmentation.RandomCrop((300, 600)))
28     # Random Shift
29     # 图像每个边填充pad,然后再随机裁剪成600*300
30     outs_1 = shift(outs)
31     outs_1 = outs_1.numpy()  # tensor转换成numpy
32     outs_1 = np.transpose(outs_1[0], [1, 2, 0]) # 取出第1幅图像(3, 300, 600)
33     # 交换维度,把第一个维度放在最后,变成(300, 600, 3)
34     # 展示图片
35     # 原图
36     temp1 = int(2*count+1) # 竖着排
37     plt.subplot(len_dir, 2, temp1) # len_dir行,2列,第temp1个图片
38     plt.imshow(image_pad)
39     plt.axis('off')
40     if count == 0:
41         plt.title("Before")
42     # 增广后
43     plt.subplot(len_dir, 2, temp1+1)
44     plt.imshow(outs_1)
45     if count == 0:
46         plt.title("After")
47     plt.axis('off')
48     plt.subplots_adjust(wspace=0.05, hspace=0.05)
49     count = count + 1
50     print('---正在处理第%d张图片---' % count)
51 
52 # plt.tight_layout()
53 plt.savefig('Aug_image2.png', bbox_inches='tight', pad_inches=0.0, dpi=1500)
54 plt.show()
55 print('-*-^_^-*-图像处理完成-*-^_^-*-')

结果:

参考:关于torch.nn.functional.grid_sample函数的说明(F.grid_sample)

posted on 2023-03-21 08:20  凯鲁嘎吉  阅读(403)  评论(0编辑  收藏  举报