keras_preprocessing参数详解
keras_preprocessing.image.image_data_generator.ImageDataGenerator.flow_from_directory()
获取目录路径并生成一批增强数据。
def flow_from_directory(self,
directory: Any,
target_size: Tuple[int, int] = (256, 256),
color_mode: str = 'rgb',
classes: Any = None,
class_mode: str = 'categorical',
batch_size: int = 32,
shuffle: bool = True,
seed: Any = None,
save_to_dir: Any = None,
save_prefix: str = '',
save_format: str = 'png',
follow_links: bool = False,
subset: Any = None,
interpolation: str = 'nearest') -> DirectoryIterator
参数说明:
-
目录(directory):字符串,目标目录的路径。每个类应该包含一个子目录。生成器中将包含每个子目录树中的任何PNG、JPG、BMP、PPM或TIF图像。有关详细信息,请参阅(https://gist.github.com/fchollet/0830affa1f7f19fd47b06d4cf89ed44d)。
-
目标大小(target_size):整数的元组(高度、宽度)。默认值:(256,256)。将调整找到的所有图像的尺寸。
-
颜色_模式(color_mode):“灰度”、“rgb”、“rgba”之一。默认值:“rgb”。是否将图像转换为具有1、3或4个通道。
-
类列表(classes):类子目录的可选列表。(例如,“狗”、“猫”)。默认值:无。如果没有提供,将根据目录下的子目录名称/结构自动推断类列表,其中每个子目录将被视为不同的类(将映射到标签索引的类的顺序将是字母数字)。包含从类名到类索引的映射的字典可以通过属性class_indexes获得。
-
类模式:确定返回的标签数组的类型:“分类(“categorical”)”、“二进制(“binary”)”、“稀疏(“sparse”)”、“输入(“input”)”、“无(None)”模式之一。默认值:“分类”。
- “categorial”则是2维one-hot编码标签;
- “binary”则是一维二进制标签;
- “sparse”是1维整数标签;
- “input”是与输入图像相同的图像(主要用于自动编码器);
- None,则不会返回任何标签(生成器将只生成一批图像数据,与model.predict_generator()一起使用非常有用)。请注意,在类模式为“无”的情况下,数据仍然需要驻留在目录的子目录中,以便它正常工作。
-
批次大小(batch_size):数据批次的大小(默认值:32)。
-
随机播放(shuffle):是否随机播放数据(默认值:True)。如果设置为False,则按字母数字顺序对数据进行排序。
-
种子(seed):可选的随机种子,用于洗牌和转换。
-
保存路径(save_to_dir):none或str(默认值:none)。这允许您选择指定一个目录,将生成的增强图片保存到该目录(对于可视化所做的操作很有用)。
-
保存前缀(save_prefix): Str. Prefix用于保存图片的文件名(仅当设置了save_to_dir时才相关)。
-
保存格式(save_format):“png”、“jpeg”之一(仅当设置了save_to_dir时才相关)。默认值:“png”。
-
跟随链接(follow_links):是否跟随类子目录内的符号链接(默认值:False)。
-
子集(subset):如果在ImageDataGenerator中设置了验证分割(validation_split),则为数据的子集(训练集"training" 或验证集"validation")。
-
插值(interpolation):当目标大小与加载的图像大小不同时,用于对图像重新采样的插值方法。支持的方法有最近邻"nearest"、双线性"bilinear"和双三次"bicubic"。如果安装了PIL版本1.1.3或更高版本,还支持“lanczos”。如果安装了PIL 3.4.0或更高版本,还支持“Box”和“Hamming”。默认情况下,使用“nearest”。
返回值
一个产生(x,y)元组的目录迭代器(DirectoryIterator)。
其中x是包含一批(batch_size,* target_size,channels)类型的图像的numpy数组,y是对应标签的numpy数组。