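The previous post used tf.data.Dataset.list_files to build a dataset of file paths and load the images step by step. That avoided loading all image data into memory at once, but the images are still read from many separate files over and over. Following the image read/write example in the official TensorFlow documentation [1], this post stores the original image data as TFRecord files, then loads them back into a Dataset and parses the data.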


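TFRecord is a simple format for storing a sequence of binary records, which can be read efficiently in a linear (sequential) fashion. The full workflow, from creating TFRecord files to using them, is: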

  • Obtain the source data
  • Write the TFRecord file
  • Read and parse the TFRecord file
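The following minimal sketch shows the format itself. It writes a few raw byte strings as consecutive records and reads them back. The file name demo.tfrecords and the byte strings are made up for illustration. TFRecordDataset returns the records in the order they were written, which is the linear reading described above.

```python
import os
import tempfile

import tensorflow as tf

records = [b"first", b"second", b"third"]  # arbitrary payloads
path = os.path.join(tempfile.mkdtemp(), "demo.tfrecords")  # throwaway file

with tf.io.TFRecordWriter(path) as writer:
    for r in records:
        writer.write(r)  # each call appends one length-prefixed, CRC-checked record

# TFRecordDataset streams the records back sequentially, in write order
read_back = [x.numpy() for x in tf.data.TFRecordDataset(path)]
print(read_back)  # [b'first', b'second', b'third']
```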

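This post focuses on obtaining the source data and splitting it into training/test sets. Writing and reading TFRecord files simply follows the standard workflow of the official example [1].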



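Common ways to split a tf.data.Dataset into training/test sets, and their limitations:

  • Combining take() and skip() is the most direct approach, but it requires knowing the dataset size
  • shard() picks out a subset at a fixed interval, but gives no way to obtain the remaining elements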
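The following sketch shows both options on a toy dataset. Dataset.range(10) stands in for the file-path dataset, and the 20% split is arbitrary. It assumes TF 2.3 or later for Dataset.cardinality(), and the size is only known here because range() has a fixed length. take/skip needs that size up front. shard(5, 0) keeps elements 0 and 5 and gives no handle on the other eight.

```python
import tensorflow as tf


def to_list(d):
    """Materialize a small integer dataset as Python ints."""
    return [int(x) for x in d.as_numpy_iterator()]


ds = tf.data.Dataset.range(10)  # stand-in for the file-path dataset

# take/skip: the split point must be computed from the dataset size
n_test = int(ds.cardinality().numpy() * 0.2)
test = to_list(ds.take(n_test))
train = to_list(ds.skip(n_test))

# shard(num_shards, index): every num_shards-th element starting at index;
# nothing returns the complementary elements
shard = to_list(ds.shard(num_shards=5, index=0))

print(test)   # [0, 1]
print(train)  # [2, 3, 4, 5, 6, 7, 8, 9]
print(shard)  # [0, 5]
```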



The approach below follows the idea in [2]:

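  • enumerate pairs each element with its index: (index, element)
  • filter selects elements by index, using two complementary predicates to split the data into training/test sets (at this point the elements are still (index, element) pairs)
  • map removes the auxiliary index from each of the two sets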
import tensorflow as tf


def _split_train_test(file_pattern, test_rate, buffer_size):
    # by default, tf.data.Dataset.list_files shuffles the file order on every iteration,
    # so disable it explicitly
    dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)

    # shuffle once and keep that order fixed across iterations, so the train/test
    # filters below both see the same element order;
    # buffer_size is recommended to be at least the dataset size (uniform shuffle)
    dataset = dataset.shuffle(buffer_size, reshuffle_each_iteration=False)

    # split train / test sets
    if test_rate:
        # every `sep`-th element goes to the test set
        sep = int(1.0 / test_rate)
        is_test = lambda x, y: x % sep == 0
        is_train = lambda x, y: x % sep != 0  # tensor-safe negation (no Python `not` on a Tensor)
        recover = lambda x, y: y              # drop the auxiliary index

        # split train/test sets, then re-enable shuffling:
        # the order within each set differs between iterations
        test_dataset = dataset.enumerate(start=1).filter(is_test).map(recover).shuffle(buffer_size)
        train_dataset = dataset.enumerate(start=1).filter(is_train).map(recover).shuffle(buffer_size)
    else:
        train_dataset, test_dataset = dataset, None

    return train_dataset, test_dataset
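Here is the same enumerate/filter/map pattern on a toy dataset, as a sketch. range(10, 20) stands in for the shuffled paths, and sep = 3 is an arbitrary choice. The example also shows why _split_train_test shuffles with reshuffle_each_iteration=False. Each of the two filters iterates the dataset separately, so both passes must see the same order. Otherwise an element could end up in both sets or in neither.

```python
import tensorflow as tf


def to_list(d):
    """Materialize a small integer dataset as Python ints."""
    return [int(x) for x in d.as_numpy_iterator()]


ds = tf.data.Dataset.range(10, 20)  # stand-in for the shuffled path dataset
sep = 3                             # hypothetical: every 3rd element goes to the test set

indexed = ds.enumerate(start=1)     # elements become (index, element), index starts at 1
test_ds = indexed.filter(lambda i, x: tf.equal(i % sep, 0)).map(lambda i, x: x)
train_ds = indexed.filter(lambda i, x: tf.not_equal(i % sep, 0)).map(lambda i, x: x)

test, train = to_list(test_ds), to_list(train_ds)
print(test)   # [12, 15, 18]
print(train)  # [10, 11, 13, 14, 16, 17, 19]
```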

Writing TFRecord files


  • Prepare the data in tf.train.Example format [3], which is similar to a dictionary {"string": tf.train.Feature}


    def _image_example(path):
        """Create a tf.train.Example with features: image raw data, labels
            path: image path, e.g. b'test\\lcrh.jpg'
        """
        # get image raw data and labels (the 4 characters before '.jpg')
        with open(path, 'rb') as f:
            image_string = f.read()
        image_labels = tf.strings.substr(path, -8, 4)
        # preparation for tf.train.Example
        # (called eagerly from a Python loop, so .numpy() is allowed here)
        feature = {
            'labels'   : tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_labels.numpy()])),
            'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_string])),
        }
        return tf.train.Example(features=tf.train.Features(feature=feature))
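One quick way to see the dictionary-like structure is to build an Example from placeholder bytes, serialize it, and parse it back with the protobuf FromString class method. The label b'lcrh' and the image bytes below are made up, and the small _bytes_feature helper is hypothetical.

```python
import tensorflow as tf


def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature (hypothetical helper)."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


feature = {
    'labels': _bytes_feature(b'lcrh'),                      # placeholder label
    'image_raw': _bytes_feature(b'\xff\xd8fake-jpeg-bytes'),  # placeholder image bytes
}
example = tf.train.Example(features=tf.train.Features(feature=feature))

serialized = example.SerializeToString()            # bytes, what goes into a TFRecord
restored = tf.train.Example.FromString(serialized)  # parse the bytes back into a message
label = restored.features.feature['labels'].bytes_list.value[0]
print(label)  # b'lcrh'
```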
  • Use tf.io.TFRecordWriter to write the serialized tf.train.Example values to a file

    import os

    # --------------------------------------------
    # Write raw image data to `images.tfrecords`
    # --------------------------------------------
    def write_images_to_tfrecord(image_file_pattern,  # e.g. 'samples/*.jpg'
            dir_record,              # folder for storing TFRecord files
            split_test=0.1,          # split train/test rate
            prefix_record='images',  # prefix of the TFRecord file names (assumed default)
            buffer_size=10000):      # shuffle buffer size, ideally >= number of images (assumed default)
        ''' load images and save to TFRecord file
                - get image raw data and labels (filename)
                - split into train / test sets
                - write to TFRecord file
        '''
        # split path dataset
        train_dataset, test_dataset = _split_train_test(image_file_pattern, split_test, buffer_size)

        # read images in the train set and save them to a TFRecord file
        train_record = os.path.join(dir_record, f'{prefix_record}_train.tfrecords')
        with tf.io.TFRecordWriter(train_record) as writer:
            for path in train_dataset.as_numpy_iterator():
                tf_example = _image_example(path)
                writer.write(tf_example.SerializeToString())

        # read images in the test set and save them to a TFRecord file
        if test_dataset is not None:
            test_record = os.path.join(dir_record, f'{prefix_record}_test.tfrecords')
            with tf.io.TFRecordWriter(test_record) as writer:
                for path in test_dataset.as_numpy_iterator():
                    tf_example = _image_example(path)
                    writer.write(tf_example.SerializeToString())
        else:
            test_record = None
        return train_record, test_record
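The writer can be tested without real images. In this sketch, make_example is a hypothetical helper that mirrors _image_example but takes bytes directly, and the samples and file name are made up. Each Example must be serialized with SerializeToString() before it is passed to TFRecordWriter.write. Reading the file back returns the labels in write order.

```python
import os
import tempfile

import tensorflow as tf


def make_example(image_bytes, label):
    """Hypothetical helper mirroring _image_example, but taking bytes directly."""
    feature = {
        'labels': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
        'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


samples = [(b'img-0', b'abcd'), (b'img-1', b'wxyz')]  # fake (image bytes, label) pairs
record = os.path.join(tempfile.mkdtemp(), 'demo_train.tfrecords')

with tf.io.TFRecordWriter(record) as writer:
    for image_bytes, label in samples:
        # TFRecordWriter.write expects bytes, so serialize the Example first
        writer.write(make_example(image_bytes, label).SerializeToString())

labels = [
    tf.train.Example.FromString(raw.numpy()).features.feature['labels'].bytes_list.value[0]
    for raw in tf.data.TFRecordDataset(record)
]
print(labels)  # [b'abcd', b'wxyz']
```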


  • Read the TFRecord file to get a tf.data.TFRecordDataset object, which inherits from tf.data.Dataset


    # --------------------------------------------
    # create dataset from TFRecord file
    # --------------------------------------------
    def create_dataset_from_tfrecord(record_file,
            image_size=(60, 120),
            grayscale=False,       # load image and convert to grayscale
            label_prefix='label'): # prefix of the label keys: label0 ... label3 (assumed default)
        '''create image/labels dataset from TFRecord file'''
        return tf.data.TFRecordDataset(record_file).map(
            lambda example_proto: _parse_image_function(example_proto, image_size, label_prefix, grayscale),
            num_parallel_calls=tf.data.experimental.AUTOTUNE)  # AUTOTUNE (-1): tf.data tunes the parallelism at runtime
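Because TFRecordDataset is a tf.data.Dataset, the usual transformations such as map and batch apply to it directly. The following sketch uses a throwaway file of raw byte strings with arbitrary contents. map with num_parallel_calls still returns results in input order by default.

```python
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), 'demo.tfrecords')  # throwaway file
with tf.io.TFRecordWriter(path) as writer:
    for word in [b'a', b'bb', b'ccc', b'dddd', b'eeeee']:
        writer.write(word)

ds = tf.data.TFRecordDataset(path)
print(isinstance(ds, tf.data.Dataset))  # True

# standard Dataset transformations work directly on the raw records
lengths = ds.map(lambda s: tf.strings.length(s),
                 num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(2)
batches = [b.numpy().tolist() for b in lengths]
print(batches)  # [[1, 2], [3, 4], [5]]
```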
  • The elements of a TFRecordDataset are serialized tf.train.Example messages, which are parsed with tf.io.parse_single_example.


    def _parse_image_function(example_proto, image_size, label_prefix, grayscale):
        '''Parse the input tf.train.Example proto using the dictionary describing the features'''
        image_feature_description = {
            'labels'   : tf.io.FixedLenFeature([], tf.string),
            'image_raw': tf.io.FixedLenFeature([], tf.string),
        }
        image_features = tf.io.parse_single_example(example_proto, image_feature_description)
        # decode image array and labels
        image_data = _decode_image(image_features['image_raw'], image_size, grayscale)
        dict_labels = _decode_labels(image_features['labels'], label_prefix)
        return image_data, dict_labels
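The parsing step can be tried without a file. In this sketch, an in-memory dataset of serialized Examples built with from_tensor_slices stands in for a TFRecordDataset. The serialize helper and the sample values are hypothetical. The feature description must use the same keys and types that were used when writing.

```python
import tensorflow as tf


def serialize(label, image_raw):
    """Hypothetical helper: build and serialize one Example with the two features."""
    feature = {
        'labels': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
        'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()


# in-memory stand-in for a TFRecordDataset: a dataset of serialized Examples
serialized = tf.data.Dataset.from_tensor_slices(
    [serialize(b'abcd', b'fake-0'), serialize(b'wxyz', b'fake-1')])

feature_description = {
    'labels': tf.io.FixedLenFeature([], tf.string),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

parsed = serialized.map(lambda p: tf.io.parse_single_example(p, feature_description))
labels = [f['labels'].numpy() for f in parsed]
print(labels)  # [b'abcd', b'wxyz']
```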
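  • Finally, decode the binary image data into RGB pixel data and convert each character of the CAPTCHA into its numeric code, producing the data for model training.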

    def _decode_image(image, resize, grayscale):
        '''preprocess image with given raw data
            - image: image raw data (JPEG-encoded bytes)
            - resize: target size (height, width)
            - grayscale: whether to convert RGB to a single grayscale channel
        '''
        image = tf.image.decode_jpeg(image, channels=3)
        # convert to floats in the [0,1] range.
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize(image, resize)
        # RGB to grayscale -> channels=1
        if grayscale:
            image = tf.image.rgb_to_grayscale(image) # shape=(h, w, 1)
        return image # (h, w, c)

    def _decode_labels(labels, prefix):
        ''' this function is used within dataset.map(),
            where eager execution is disabled by default:
                tf.executing_eagerly() returns False here.
            So no Tensor.numpy() call is allowed in this function.
            Each lowercase character 'a'..'z' of the 4-character label maps to 0..25.
        '''
        dict_labels = {}
        for i in range(4):
            c = tf.strings.substr(labels, i, 1) # labels example: b'abcd'
            label = tf.strings.unicode_decode(c, input_encoding='utf-8') - ord('a')
            dict_labels[f'{prefix}{i}'] = label
        return dict_labels
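The following sketch runs both decoding steps on synthetic input inside Dataset.map. A blank 40x80 JPEG produced by tf.io.encode_jpeg replaces a real CAPTCHA image. The label b'abcz', the prefix 'label', and the helper names decode_labels and decode are made up for illustration. The functions run in graph mode, so they use only TensorFlow ops and never call .numpy().

```python
import tensorflow as tf


def decode_labels(labels, prefix='label'):
    """Sketch of _decode_labels: lowercase 'a'..'z' become integers 0..25."""
    dict_labels = {}
    for i in range(4):
        c = tf.strings.substr(labels, i, 1)  # i-th character (1 byte)
        # unicode_decode on a scalar string returns a 1-D int32 tensor of code points
        dict_labels[f'{prefix}{i}'] = tf.strings.unicode_decode(c, input_encoding='UTF-8') - ord('a')
    return dict_labels


def decode(image_raw, labels):
    """Mirrors _decode_image with resize=(60, 120) and grayscale=True."""
    image = tf.image.decode_jpeg(image_raw, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)  # uint8 -> float32 in [0, 1]
    image = tf.image.resize(image, (60, 120))
    image = tf.image.rgb_to_grayscale(image)                  # -> (60, 120, 1)
    return image, decode_labels(labels)


# synthetic input: a blank 40x80 JPEG and a made-up label
jpeg = tf.io.encode_jpeg(tf.zeros((40, 80, 3), dtype=tf.uint8))
ds = tf.data.Dataset.from_tensor_slices(([jpeg], [b'abcz'])).map(decode)

image, labels = next(iter(ds))
print(image.shape)  # (60, 120, 1)
print({k: v.numpy().tolist() for k, v in labels.items()})
# {'label0': [0], 'label1': [1], 'label2': [2], 'label3': [25]}
```

Each character becomes a tensor of shape (1,) rather than a scalar, because unicode_decode on a scalar string returns a 1-D tensor of code points.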