深度学习之图像的数据增强
在图像的深度学习中,为了丰富图像训练集,更好地提取图像特征,提高模型的泛化能力(防止模型过拟合),一般都会对图像数据进行数据增强.
数据增强常用的方式有:旋转图像,剪切图像,改变图像色差,扭曲图像特征,改变图像尺寸大小,增加图像噪声(一般使用高斯噪声,椒盐噪声)等.
但是需要注意,不要加入其他图像轮廓的噪音.
对于常用的图像的数据增强的实现,如下:
# -*- coding:utf-8 -*-
"""Data augmentation helpers for image training sets.

Augmentation techniques listed by the original author:
    1. flip
    2. random crop
    3. color jittering
    4. shift
    5. scale
    6. contrast
    7. noise
    8. rotation / reflection

Only rotation, crop, color jittering and Gaussian noise are implemented
in this module.

author: XiJun.Gong
date:2016-11-29
"""

from PIL import Image, ImageEnhance, ImageOps, ImageFile
import numpy as np
import random
import threading
import os
import time
import logging

logger = logging.getLogger(__name__)
# Let PIL decode image files whose data stream is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class DataAugmentation:
    """Namespace of static augmentation helpers that operate on PIL images."""

    def __init__(self):
        pass

    @staticmethod
    def openImage(image):
        """Open ``image`` (anything ``Image.open`` accepts) in read mode."""
        return Image.open(image, mode="r")

    @staticmethod
    def randomRotation(image, mode=Image.BICUBIC):
        """Rotate the image by a random whole angle in [1, 359] degrees.

        :param image: PIL image
        :param mode: resampling filter passed to ``Image.rotate``
                     (bicubic by default)
        :return: the rotated PIL image
        """
        random_angle = int(np.random.randint(1, 360))
        return image.rotate(random_angle, mode)

    @staticmethod
    def randomCrop(image):
        """Crop a centred square window whose side is random in [40, 67].

        NOTE: only the window *size* is random; the window is always
        centred on the image. The size range is hard-coded; the original
        author's note says the images are expected to be about 68x68.

        :param image: PIL image
        :return: the cropped PIL image
        """
        image_width, image_height = image.size[0], image.size[1]
        crop_win_size = int(np.random.randint(40, 68))
        # Floor division matches the original ``>> 1`` for negative values too.
        left = (image_width - crop_win_size) // 2
        upper = (image_height - crop_win_size) // 2
        right = (image_width + crop_win_size) // 2
        lower = (image_height + crop_win_size) // 2
        return image.crop((left, upper, right, lower))

    @staticmethod
    def randomColor(image):
        """Apply random saturation, brightness, contrast and sharpness changes.

        Factors are drawn in steps of 0.1: saturation and sharpness from
        [0.0, 3.0], brightness and contrast from [1.0, 2.0].

        :param image: PIL image
        :return: the color-jittered PIL image
        """
        random_factor = np.random.randint(0, 31) / 10.
        color_image = ImageEnhance.Color(image).enhance(random_factor)  # saturation
        random_factor = np.random.randint(10, 21) / 10.
        brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)  # brightness
        random_factor = np.random.randint(10, 21) / 10.
        contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)  # contrast
        random_factor = np.random.randint(0, 31) / 10.
        return ImageEnhance.Sharpness(contrast_image).enhance(random_factor)  # sharpness

    @staticmethod
    def randomGaussian(image, mean=0.2, sigma=0.3):
        """Add independent Gaussian noise to every pixel of the image.

        For multi-channel images the noise is applied to the first three
        channels only (any further channel, e.g. alpha, is left untouched).
        Single-channel images receive noise on their only channel.

        :param image: PIL image
        :param mean: mean (offset) of the noise, in pixel-value units
        :param sigma: standard deviation of the noise, in pixel-value units
        :return: a new 8-bit PIL image with the noise applied
        """

        def gaussianNoisy(im, mean=0.2, sigma=0.3):
            """Return ``im`` plus element-wise Gaussian noise.

            :param im: float array of pixel values
            :param mean: offset of the noise
            :param sigma: standard deviation of the noise
            :return: noisy float array with the same shape as ``im``
            """
            return im + np.random.normal(mean, sigma, size=im.shape)

        # Work on a float copy: the array exposed by a PIL image is read-only,
        # and adding float noise directly to uint8 data wraps around on
        # negative results instead of saturating.
        img = np.array(image, dtype=np.float64)
        if img.ndim == 2:
            img = gaussianNoisy(img, mean, sigma)
        else:
            img[..., :3] = gaussianNoisy(img[..., :3], mean, sigma)
        # Saturate to the valid 8-bit range before the (truncating) cast.
        return Image.fromarray(np.uint8(np.clip(img, 0, 255)))

    @staticmethod
    def saveImage(image, path):
        """Save the PIL image to ``path``."""
        image.save(path)


def makeDir(path):
    """Create directory ``path`` (including parents) if it does not exist.

    :param path: directory to create
    :return: 0 if the path did not exist beforehand, 1 if it already
             existed, -2 if an error occurred (the error text is printed)
    """
    try:
        if not os.path.exists(path):
            if not os.path.isfile(path):
                os.makedirs(path)
            return 0
        else:
            return 1
    except Exception as e:
        # Single-argument print() behaves identically on Python 2 and 3.
        print(str(e))
        return -2


def imageOps(func_name, image, des_path, file_name, times=5):
    """Apply one named augmentation ``times`` times and save every result.

    Each output is written to
    ``des_path/<func_name><iteration index><file_name>``.

    :param func_name: one of "randomRotation", "randomCrop", "randomColor",
                      "randomGaussian"
    :param image: PIL image used as the input of every iteration
    :param des_path: directory that receives the generated images
    :param file_name: original file name, used as the output name suffix
    :param times: number of augmented images to generate
    :return: -1 if ``func_name`` is unknown, otherwise None
    """
    funcMap = {"randomRotation": DataAugmentation.randomRotation,
               "randomCrop": DataAugmentation.randomCrop,
               "randomColor": DataAugmentation.randomColor,
               "randomGaussian": DataAugmentation.randomGaussian
               }
    augment = funcMap.get(func_name)
    if augment is None:
        logger.error("%s is not exist", func_name)
        return -1

    for _i in range(times):
        new_image = augment(image)
        DataAugmentation.saveImage(new_image, os.path.join(des_path, func_name + str(_i) + file_name))
# Names of the augmentations applied to every image; each one must be a key
# understood by imageOps.
opsList = {"randomRotation", "randomCrop", "randomColor", "randomGaussian"}


def threadOPS(path, new_path):
    """Augment every image under ``path`` and write the results to ``new_path``.

    Sub-directories are processed recursively and mirrored below
    ``new_path``. For each image file one thread per entry of ``opsList``
    is started; all of them are joined before the next file is handled.
    Files named ``.DS_Store`` are skipped.

    :param path: source directory, or a single image file
    :param new_path: destination directory (created if missing)
    :return: -1 if a destination directory could not be created,
             otherwise None
    """
    if os.path.isdir(path):
        src_dir = path
        img_names = os.listdir(path)
    else:
        # A single file: split it so that joining below rebuilds the real
        # path and the output name is just the file's base name.
        src_dir, single_name = os.path.split(path)
        img_names = [single_name]

    # makeDir reports failure with a negative code (-2), success with 0 or 1.
    if makeDir(new_path) < 0:
        print('create new dir failure')
        return -1

    for img_name in img_names:
        print(img_name)
        tmp_img_name = os.path.join(src_dir, img_name)
        if os.path.isdir(tmp_img_name):
            sub_new_path = os.path.join(new_path, img_name)
            if makeDir(sub_new_path) < 0:
                print('create new dir failure')
                return -1
            threadOPS(tmp_img_name, sub_new_path)
        elif img_name != ".DS_Store":
            image = DataAugmentation.openImage(tmp_img_name)
            try:
                # Decode once up front so the worker threads do not trigger
                # PIL's lazy loading concurrently on the shared image.
                image.load()
                workers = [threading.Thread(target=imageOps,
                                            args=(ops_name, image, new_path, img_name,))
                           for ops_name in opsList]
                for worker in workers:
                    worker.start()
                for worker in workers:
                    worker.join()
            finally:
                # Safe to release: every worker has finished with the image.
                image.close()


if __name__ == '__main__':
    threadOPS("/home/pic-image/train/12306train",
              "/home/pic-image/train/12306train3")
编程是一种快乐,享受代码带给我的乐趣!!!