TensorFlow 数据加载与预处理
一、数据加载
1、加载文件夹中的图片数据:每个类别的图片放在各自的子文件夹下(子文件夹名即类别名)
# Load labelled images from a directory tree: image_dataset_from_directory
# infers one class per subdirectory of data_root.
batch_size = 32  # each element of train_ds (e.g. next(iter(train_ds))) is a batch of 32 images
img_height = 224
img_width = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    str(data_root),  # data_root: root folder holding one subfolder per class
    validation_split=0.2,  # hold out 20% of the files
    subset="training",  # this call returns the remaining 80%
    seed=123,  # fixed seed so the training/validation split is reproducible
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
# class_names lists the class (subfolder) names; its length is the number of classes.
print(train_ds.class_names)
print(len(train_ds.class_names))
2、加载 tensorflow_datasets 数据:通过 split 参数(切片语法)将数据划分为训练数据、验证数据和测试数据
import tensorflow_datasets as tfds
# Load tf_flowers, slicing its 'train' split 80% / 10% / 10% into
# train / validation / test sets; as_supervised=True yields (image, label) pairs.
(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)
print(metadata.features['label'].num_classes)  # number of classes in the dataset
get_label_name = metadata.features['label'].int2str  # maps an integer label to its class name
print(get_label_name(0))  # class name for label 0 (printed so it also shows outside notebooks)
二、数据预处理
https://tensorflow.google.cn/tutorials/images/data_augmentation?hl=zh_cn
1、用 lambda 表达式配合 Dataset.map 对整个数据集进行处理(第 5 节中也涉及)
# Rescale pixel values from [0, 255] to [0, 1].
# NOTE: newer TF versions also expose this layer as tf.keras.layers.Rescaling.
normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255)
# Apply the layer to every image in the dataset; labels pass through unchanged.
# (This must be its own statement: when it trailed the comment on the same
# line, the map was never executed.)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
2、通用数据增强层
1)layers.Lambda层进行图像增强
def random_invert_img(x, p=0.5):
    """Return ``255 - x`` with probability ``p``, otherwise ``x`` unchanged.

    A single scalar random draw is made per call, so when ``x`` is a batch
    the whole batch is either inverted or left as is.

    Args:
        x: image tensor; assumes pixel values in [0, 255] (TODO confirm for
            callers that rescale first).
        p: probability of inverting.

    Returns:
        The (possibly) inverted tensor.
    """
    if tf.random.uniform([]) < p:
        x = 255 - x
    return x
def random_invert(factor=0.5):
    """Build a Keras Lambda layer that applies random_invert_img with probability ``factor``."""
    def _invert(images):
        return random_invert_img(images, factor)

    return layers.Lambda(_invert)
# random_invert() only builds a Lambda layer; the layer must then be called on
# the image. Passing the image directly, as in random_invert(image), would
# bind it to `factor` and never transform it. The Lambda layer passes the
# whole tensor to random_invert_img, where 255 - x acts element-wise on every pixel.
inverted_image = random_invert()(image)
2)通过subclassing新建一个层
class RandomInvert(layers.Layer):
    """Keras layer that inverts its input (255 - x) with probability ``factor``."""

    def __init__(self, factor=0.5, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, x):
        # Use the configured probability; it was previously hard-coded to 0.3,
        # which silently ignored the ``factor`` constructor argument.
        return random_invert_img(x, self.factor)

    def get_config(self):
        # Include ``factor`` so the layer can be re-created from its config.
        config = super().get_config()
        config.update({"factor": self.factor})
        return config
# Run a freshly built RandomInvert layer, then display the first image of the
# result (assumes `image` has a leading batch dimension — TODO confirm).
inverted = RandomInvert()(image)
_ = plt.imshow(inverted[0])
3、tf.image、tf.data实现数据增强管道或层
flipped = tf.image.flip_left_right(image)  # mirror horizontally
grayscaled = tf.image.rgb_to_grayscale(image)  # output keeps a channel axis of size 1
grayscaled_squeezed = tf.squeeze(grayscaled)  # drop size-1 axes (e.g. before plotting)
saturated = tf.image.adjust_saturation(image, 3)  # scale saturation by a factor of 3
bright = tf.image.adjust_brightness(image, 0.4)  # add 0.4 to the pixel values
cropped = tf.image.central_crop(image, central_fraction=0.5)  # keep the central 50% of each spatial dim
rotated = tf.image.rot90(image)  # rotate 90 degrees counter-clockwise
4、随机变换(stateless random 操作)
# Stateless random ops: the same seed always produces the same result.
#   tf.image.stateless_random_brightness
#   tf.image.stateless_random_contrast
#   tf.image.stateless_random_crop
#   tf.image.stateless_random_flip_left_right
#   tf.image.stateless_random_flip_up_down
#   tf.image.stateless_random_hue
#   tf.image.stateless_random_jpeg_quality
#   tf.image.stateless_random_saturation
# Example (`i` is assumed to be a loop index, e.g. `for i in range(3):` — TODO confirm):
seed = (i, 0)  # seed is a tuple of size (2,); same seed -> same output
stateless_random_brightness = tf.image.stateless_random_brightness(
    image, max_delta=0.95, seed=seed)  # randomly change brightness
stateless_random_contrast = tf.image.stateless_random_contrast(
    image, lower=0.1, upper=0.9, seed=seed)  # randomly change contrast
# Randomly crop a 210x300x3 region; assumes `image` is at least that large — TODO confirm.
stateless_random_crop = tf.image.stateless_random_crop(
    image, size=[210, 300, 3], seed=seed)
5、数据集增强
# Download tf_flowers and slice its 'train' split 80% / 10% / 10%.
(train_datasets, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)


def resize_and_rescale(image, label):
    """Cast to float32, resize to IMG_SIZE x IMG_SIZE and scale pixels by 1/255.

    IMG_SIZE is not defined in this snippet and must exist at module level
    before this is called. Assumes input pixels are in [0, 255] (TODO confirm).
    """
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    image = image / 255.0
    return image, label


def augment(image_label, seed):
    """Deterministically augment one (image, label) pair from a seed.

    Resizes and rescales, pads to (IMG_SIZE + 6) square, randomly crops back
    to IMG_SIZE x IMG_SIZE x 3, randomly shifts brightness, and clips to [0, 1].
    The same seed always yields the same output.
    """
    image, label = image_label
    image, label = resize_and_rescale(image, label)
    image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6)
    # Derive a second seed so the crop and the brightness op use different randomness.
    new_seed = tf.random.experimental.stateless_split(seed, num=1)[0, :]
    # Random crop back to the original size.
    image = tf.image.stateless_random_crop(
        image, size=[IMG_SIZE, IMG_SIZE, 3], seed=seed)
    # Random brightness.
    image = tf.image.stateless_random_brightness(
        image, max_delta=0.5, seed=new_seed)
    # A brightness shift can push values outside [0, 1]; clip them back.
    image = tf.clip_by_value(image, 0, 1)
    return image, label
1)利用tf.data.experimental.Counter()
(1)创建一个计数器,并将数据集与 (counter, counter) 进行 zip,这样数据集中的每张图像都与一个基于计数器的唯一值(形状为 (2,))相关联,该值随后可作为随机变换的 seed 传入 augment 函数
Create a tf.data.experimental.Counter() object (let's call it counter) and zip the dataset with (counter, counter). This will ensure that each image in the dataset gets associated with a unique value (of shape (2,)) based on counter, which later can get passed into the augment function as the seed value for random transformations.
# Infinite dataset of increasing integers 0, 1, 2, ...
counter = tf.data.experimental.Counter()
# Pair every (image, label) element with a (counter, counter) tuple, so each
# element gets a unique shape-(2,) seed to pass to augment().
train_ds = tf.data.Dataset.zip((train_datasets, (counter, counter)))
(2)利用映射函数处理数据集
# Assumes AUTOTUNE is defined elsewhere (e.g. AUTOTUNE = tf.data.AUTOTUNE) — TODO confirm.
# Training data: shuffle, then apply the seeded augmentation to each
# ((image, label), seed) element, batch and prefetch.
train_ds = (
    train_ds
    .shuffle(1000)
    .map(augment, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
# Validation and test data are only resized and rescaled, never augmented.
val_ds = (
    val_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
test_ds = (
    test_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
2)利用tf.random.Generator
(1)创建映射函数
Note: tf.random.Generator objects store RNG state in a tf.Variable, which means they can be saved as a checkpoint or in a SavedModel. For more details, please refer to the Random number generation guide.
# Create a stateful generator; its state lives in a tf.Variable.
rng = tf.random.Generator.from_seed(123, alg='philox')


# A wrapper function for updating seeds: each call draws a fresh seed from
# `rng` and passes it to augment() together with the (image, label) pair.
# (Previously this def followed a `#` on the same line and was never defined.)
def f(x, y):
    # make_seeds(1) returns shape [2, 1]; [:, 0] yields one shape-(2,) seed.
    seed = rng.make_seeds(1)[:, 0]
    image, label = augment((x, y), seed)
    return image, label