TensorFlow自定义数据集:tf.data.Dataset

发布于:2020-02-27 | 分类:machine learning


当样本数据大到不能一次性载入内存时,TensorFlow推荐使用tf.data.Dataset创建样本的输入数据流,进而投喂给模型进行训练model.fit(dataset)。本文以带标签(即文件名)的验证码图片数据集为例,记录基本流程及一些备忘点。

基本流程

tf.data.Dataset的使用遵循基本的流程 1

  • 从输入数据创建Dataset,例如from_tensor_sliceslist_files
  • 应用Dataset变换预处理数据,例如mapfiltershufflebatchrepeatcache等等
  • 遍历Dataset并生成数据

针对本例中从文件夹读取图片的问题,我们直接可以从tf.data.Dataset.list_files开始。以下参考了TensorFlow手册 2 中相关内容。

创建源数据集:tf.data.Dataset.list_files

该函数返回一个Dataset,其元素为满足给定模式的所有文件路径,并且元素默认 随机不确定 排列。不确定 指的是每次遍历时得到的顺序都不一样。

如果需要得到固定的顺序,可以设置一个确定的seed或者关闭打乱选项shuffle=False

list_files(
    file_pattern, shuffle=None, seed=None
)
  • file_pattern 需要载入文件的路径模式,例如samples/*.jpg
  • shuffle 是否打乱数据集,默认**是**
  • seed 用以打乱数据集的随机数种子

数据集变换预处理数据

我们已经得到了文件名,但真正需要喂给模型的是图片本身及其标签,所以需要在这个源数据集上进行变换操作。也就是说,Dataset的元素应满足(X, Y)的形式,当包含多特征的输入或输出,例如本例中输出四位的验证码,可以以字典的形式构造(X, Y)元组中的XY

(image_raw_data, {
  'label_1': 'A',
  'label_2': 'b',
  'label_3': 'c',
  'label_4': 'D'
  })

注意

当定义多输入多输出的模型结构时,输入输出的名称应与此处的定义前后一致 3

tf.data.Dataset.map

TensorFlow提供了map函数将预定义的变换操作map_func作用在Dataset的每一个元素上,然后返回这个新的Dataset

map(
    map_func, num_parallel_calls=None
)
  • map_func 以原来Dataset中的元素为参数的自定义函数,返回新Dataset中相应的元素
  • num_parallel_calls 并行处理元素的个数,默认顺序执行

需要注意的是map函数以Graph的形式执行自定义的map_func,因此EagerTensor的性质例如numpy()将不可用。如果非用不可,API文档 1 中提及了使用tf.py_function转换的形式,但是以性能损失为代价。

具体到本例中,

path_pattern = 'samples/*.jpg'

dataset = tf.data.Dataset.list_files(path_pattern).map(
        lambda image_path: (
          _decode_image(image_path),  # load images, perform transformation as X
          _decode_labels(image_path)  # parse labels as Y
    ))

tf.data.Dataset.shuffle

shuffle打乱当前Dataset并返回乱序后的Dataset,方便链式操作。

shuffle(
    buffer_size, seed=None, reshuffle_each_iteration=None
)
  • buffer_size 缓冲区大小。元素被依次填入缓冲区,然后从中随机取出以达到打乱效果。因此,buffer_size越大,乱序效果越好,但性能随之下降。
  • seed 打乱用的随机数种子
  • reshuffle_each_iteration 是否每次遍历时都自动打乱,默认 。避免不同epoch的训练过程中,Dataset保持一致的顺序。

tf.data.Dataset.batch

将一定数量的元素组织为一个batch,得到新的Dataset。同为链式操作。

batch(
    batch_size, drop_remainder=False
)
  • batch_size 批次的大小
  • drop_remainder 当原来样本数量不能被batch_size整除时,是否丢弃最后剩下的不足一个批次的样本。默认 保留

此外,Dataset还有一些列实用的操作,例如filter筛选元素、cache提升性能,具体操作API文档 1

代码汇总

import tensorflow as tf

# --------------------------------------------
# create dataset from path pattern
# --------------------------------------------
def create_dataset_from_path(path_pattern, 
    batch_size=32, 
    image_size=(60, 120), 
    label_prefix='labels',
    grayscale=False): # load image and convert to grayscale
    # create path dataset
    # by default, `tf.data.Dataset.list_files` gets filenames 
    # in a non-deterministic random shuffled order
    return tf.data.Dataset.list_files(path_pattern).map(
        lambda image_path: _parse_path_function(image_path, image_size, label_prefix, grayscale)
    ).batch(batch_size)


def _parse_path_function(path, image_size, label_prefix, grayscale):
    '''parse image data and labels from path'''
    raw_image = open(path, 'rb').read()
    labels = tf.strings.substr(path, -8, 4) # path example: b'xxx\abcd.jpg'
    # decode image array and labels
    image_data = _decode_image(raw_image, image_size, grayscale)
    dict_labels = _decode_labels(labels, label_prefix)

    return image_data, dict_labels


def _decode_image(image, resize, grayscale):
    '''preprocess image with given raw data
        - image: image raw data
    '''
    image = tf.image.decode_jpeg(image, channels=3)

    # convert to floats in the [0,1] range.
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, resize)

    # RGB to grayscale -> channels=1
    if grayscale:
        image = tf.image.rgb_to_grayscale(image) # shape=(h, w, 1)

    return image # (h, w, c)


def _decode_labels(labels, prefix):
    ''' this function is used within dataset.map(), 
        where eager execution is disables by default:
            check tf.executing_eagerly() returns False.
        So any Tensor.numpy() is not allowed in this function.
    '''
    dict_labels = {}
    for i in range(4):
        c = tf.strings.substr(labels, i, 1) # labels example: b'abcd'
        label = tf.strings.unicode_decode(c, input_encoding='utf-8') - ord('a')
        dict_labels[f'{prefix}{i}'] = label

    return dict_labels

至此可以将整个文件夹的图片喂给模型训练了,但是数十万张图片既不方便传输、频繁读文件操作也影响性能,因此下篇将所有数据写入TFRecord文件,然后使用tf.data.TFRecordDataset导入。