TensorFlow自定义数据集:tf.data.Dataset¶
发布于:2020-02-27 | 分类:machine learning
当样本数据大到不能一次性载入内存时,TensorFlow推荐使用tf.data.Dataset创建样本的输入数据流,进而投喂给模型进行训练model.fit(dataset)。本文以带标签(即文件名)的验证码图片数据集为例,记录基本流程及一些备忘点。
基本流程¶
tf.data.Dataset的使用遵循基本的流程 1:
- 从输入数据创建
Dataset,例如from_tensor_slices,list_files - 应用
Dataset变换预处理数据,例如map、filter、shuffle、batch、repeat、cache等等 - 遍历
Dataset并生成数据
针对本例中从文件夹读取图片的问题,我们直接可以从tf.data.Dataset.list_files开始。以下参考了TensorFlow手册 2 中相关内容。
创建源数据集:tf.data.Dataset.list_files¶
该函数返回一个Dataset,其元素为满足给定模式的所有文件路径,并且元素默认 随机、不确定 排列。不确定 指的是每次遍历时得到的顺序都不一样。
如果需要得到固定的顺序,可以设置一个确定的seed或者关闭打乱选项shuffle=False。
list_files(
file_pattern, shuffle=None, seed=None
)
file_pattern需要载入文件的路径模式,例如samples/*.jpgshuffle是否打乱数据集,默认**是**seed用以打乱数据集的随机数种子
数据集变换预处理数据¶
我们已经得到了文件名,但真正需要喂给模型的是图片本身及其标签,所以需要在这个源数据集上进行变换操作。也就是说,Dataset的元素应满足(X, Y)的形式,当包含多特征的输入或输出,例如本例中输出四位的验证码,可以以字典的形式构造(X, Y)元组中的X或Y:
(image_raw_data, {
'label_1': 'A',
'label_2': 'b',
'label_3': 'c',
'label_4': 'D'
})
注意
当定义多输入多输出的模型结构时,输入输出的名称应与此处的定义前后一致 3。
tf.data.Dataset.map¶
TensorFlow提供了map函数将预定义的变换操作map_func作用在Dataset的每一个元素上,然后返回这个新的Dataset。
map(
map_func, num_parallel_calls=None
)
map_func以原来Dataset中的元素为参数的自定义函数,返回新Dataset中相应的元素num_parallel_calls并行处理元素的个数,默认顺序执行
需要注意的是map函数以Graph的形式执行自定义的map_func,因此EagerTensor的性质例如numpy()将不可用。如果非用不可,API文档 1 中提及了使用tf.py_function转换的形式,但是以性能损失为代价。
具体到本例中,
path_pattern = 'samples/*.jpg'
dataset = tf.data.Dataset.list_files(path_pattern).map(
lambda image_path: (
_decode_image(image_path), # load images, perform transformation as X
_decode_labels(image_path) # parse labels as Y
))
tf.data.Dataset.shuffle¶
shuffle打乱当前Dataset并返回乱序后的Dataset,方便链式操作。
shuffle(
buffer_size, seed=None, reshuffle_each_iteration=None
)
buffer_size缓冲区大小。元素被依次填入缓冲区,然后从中随机取出以达到打乱效果。因此,buffer_size越大,乱序效果越好,但性能随之下降。seed打乱用的随机数种子reshuffle_each_iteration是否每次遍历时都自动打乱,默认 是。避免不同epoch的训练过程中,Dataset保持一致的顺序。
tf.data.Dataset.batch¶
将一定数量的元素组织为一个batch,得到新的Dataset。同为链式操作。
batch(
batch_size, drop_remainder=False
)
batch_size批次的大小drop_remainder当原来样本数量不能被batch_size整除时,是否丢弃最后剩下的不足一个批次的样本。默认 保留。
此外,Dataset还有一些列实用的操作,例如filter筛选元素、cache提升性能,具体操作API文档 1。
代码汇总¶
import tensorflow as tf
# --------------------------------------------
# create dataset from path pattern
# --------------------------------------------
def create_dataset_from_path(path_pattern,
batch_size=32,
image_size=(60, 120),
label_prefix='labels',
grayscale=False): # load image and convert to grayscale
# create path dataset
# by default, `tf.data.Dataset.list_files` gets filenames
# in a non-deterministic random shuffled order
return tf.data.Dataset.list_files(path_pattern).map(
lambda image_path: _parse_path_function(image_path, image_size, label_prefix, grayscale)
).batch(batch_size)
def _parse_path_function(path, image_size, label_prefix, grayscale):
'''parse image data and labels from path'''
raw_image = open(path, 'rb').read()
labels = tf.strings.substr(path, -8, 4) # path example: b'xxx\abcd.jpg'
# decode image array and labels
image_data = _decode_image(raw_image, image_size, grayscale)
dict_labels = _decode_labels(labels, label_prefix)
return image_data, dict_labels
def _decode_image(image, resize, grayscale):
'''preprocess image with given raw data
- image: image raw data
'''
image = tf.image.decode_jpeg(image, channels=3)
# convert to floats in the [0,1] range.
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, resize)
# RGB to grayscale -> channels=1
if grayscale:
image = tf.image.rgb_to_grayscale(image) # shape=(h, w, 1)
return image # (h, w, c)
def _decode_labels(labels, prefix):
''' this function is used within dataset.map(),
where eager execution is disables by default:
check tf.executing_eagerly() returns False.
So any Tensor.numpy() is not allowed in this function.
'''
dict_labels = {}
for i in range(4):
c = tf.strings.substr(labels, i, 1) # labels example: b'abcd'
label = tf.strings.unicode_decode(c, input_encoding='utf-8') - ord('a')
dict_labels[f'{prefix}{i}'] = label
return dict_labels
至此可以将整个文件夹的图片喂给模型训练了,但是数十万张图片既不方便传输、频繁读文件操作也影响性能,因此下篇将所有数据写入TFRecord文件,然后使用tf.data.TFRecordDataset导入。