9. TFRecords分析、存取 – Python量化投资

9. TFRecords分析、存取

TFRecords是Tensorflow设计的一种内置文件格式,是一种二进制文件,它能更好的利用内存,更方便复制和移动。
为了将二进制数据和标签(训练的类别标签)数据存储在同一个文件中。

TFRecords文件分析

文件格式:*.tfrecords
写入文件内容:Example协议块

TFRecords存储

  • 1、建立TFRecord存储器
    tf.python_io.TFRecordWriter(path)
    写入tfrecords文件
    path: TFRecords文件的路径
    return:存储到tfrecords的写入器
    方法:
    write(record): 向文件中写入一个字符串记录。 字符串为一个序列化的Example,Example.SerializeToString()
    close(): 关闭文件写入器
  • 2、构造每个样本的Example协议块
    tf.train.Example(features=None)
    写入tfrecords文件
    features:tf.train.Features类型的特征实例
    return:example格式协议块
    tf.train.Features(feature=None)
    构建每个样本的信息键值对
    feature:字典数据,key为要保存的名字,value为tf.train.Feature实例
    return:Features类型
    tf.train.Feature(**options)
    **options:例如
    bytes_list=tf.train.BytesList(value=[Bytes])
    int64_list=tf.train.Int64List(value=[Value])
    tf.train.Int64List(value=[Value])
    tf.train.BytesList(value=[Bytes])
    tf.train.FloatList(value=[value])

TFRecords读取方法

同文件阅读器流程,中间需要解析过程
解析TFRecords的example协议内存块
tf.parse_single_example(serialized, features=None, name=None)
解析一个单一的Example原型
serialized:标量字符串Tensor,一个序列化的Example
features:dict字典数据,键为读取的名字,值为FixedLenFeature
return: 一个键值对组成的字典,键为读取的名字
tf.FixedLenFeature(shape,dtype)
shape:输入数据的形状,一般不指定,为空列表
dtype:输入数据类型,与存储进文件的类型要一致
类型只能是 float32, int64, string

features = tf.parse_single_example(value, features={
    'image': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64)
})

存储代码

import tensorflow as tf
import os
# 1. 创建文件队列
filename_list = os.listdir('./cifar/cifar-10-batches-bin')
file_list = [os.path.join('./cifar/cifar-10-batches-bin', filename) for filename in filename_list if filename[-3:] == 'bin']
file_queue = tf.train.string_input_producer(file_list)
# 2. 创建读取器读取文件
reader = tf.FixedLengthRecordReader(32*32*3 + 1)
k, v = reader.read(file_queue)
# 3. 解码读取的内容 uint8类型
v_decode = tf.decode_raw(v, tf.uint8)  # 返回 A `Tensor` of type `out_type`
label = tf.slice(v_decode, [0], [1])
image = tf.slice(v_decode, [0], [32*32*3])
# 4. 批量处理
label_batch, image_batch = tf.train.batch([label, image], batch_size=10, num_threads=1, capacity=10)
image_batch = tf.reshape(image_batch, [10, 32, 32, 3])
# 创建存储到tfrecords的写入器
writer = tf.python_io.TFRecordWriter('./save_2_tfrecords.tfrecords')
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    # 存储到tfrecords类型文件 每次写入一个example协议块 每次写入一个
    for i in range(10):
        writer.write(tf.train.Example(features=tf.train.Features(feature={
            "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_batch[i].eval().tostring()])),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=label_batch[i].eval()))
        })).SerializeToString())
    writer.close()
    print(sess.run([label_batch, image_batch]))
    coord.request_stop()
    coord.join(threads)

读取代码

import tensorflow as tf
# 1. 创建文件队列
file_list = ['./save_2_tfrecords.tfrecords']
file_queue = tf.train.string_input_producer(file_list)
# 2. 创建tfreocrds的读取器并读取内容
reader = tf.TFRecordReader()
k, v = reader.read(file_queue)
# 3. 解析example块的字节内容
features = tf.parse_single_example(v, features={
    'image': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64)
})
# 4. 解码
image = tf.decode_raw(features['image'], tf.uint8)
image = tf.reshape(image, [32, 32, 3])
label = tf.cast(features['label'], tf.int32)
image_batch, label_batch = tf.train.batch([image, label], batch_size=10, num_threads=1, capacity=10)
print(image_batch, label_batch)
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess)
    print(sess.run([image_batch, label_batch]))
    coord.request_stop()
    coord.join(threads)

https://www.jianshu.com/p/3dc2c77415d9

「点点赞赏,手留余香」

    还没有人赞赏,快来当第一个赞赏的人吧!
0 条回复 A 作者 M 管理员
    所有的伟大,都源于一个勇敢的开始!
欢迎您,新朋友,感谢参与互动!欢迎您 {{author}},您在本站有{{commentsCount}}条评论