tfrecords格式是什么,写入与读取是怎样的
Admin 2022-09-16 群英技术资讯 659 次浏览
这篇文章给大家分享的是tfrecords格式是什么,写入与读取是怎样的。小编觉得挺实用的,因此分享给大家做个参考,文中的介绍得很详细,而要易于理解和学习,有需要的朋友可以参考,接下来就跟随小编一起了解看看吧。前一段时间对SSD预测与训练的整体框架有了一定的了解,但是对其中很多细节还是把握的不清楚。今天我决定好好了解以下tfrecords文件的构造。
tfrecords是一种二进制编码的文件格式,tensorflow专用。能将任意数据转换为tfrecords。更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。
之所以使用到tfrecords格式是因为当今数据爆炸的情况下,使用普通的数据格式不仅麻烦,而且速度慢,这种专门为tensorflow定制的数据格式可以大大增快数据的读取,而且将所有内容规整,在保证速度的情况下,使得数据更加简单明晰。
这个例子将会讲述如何将MNIST数据集写入到tfrecords,本次用到的MNIST数据集会利用tensorflow原有的库进行导入。
from tensorflow.examples.tutorials.mnist import input_data
# 读取MNIST数据集
mnist = input_data.read_data_sets('./MNIST_data', dtype=tf.float32, one_hot=True)
对于MNIST数据集而言,其中的训练集是mnist.train,而它的数据可以分为images和labels,可通过如下方式获得。
# 获得image,shape为(55000,784) images = mnist.train.images # 获得label,shape为(55000,10) labels = mnist.train.labels # 获得一共具有多少张图片 num_examples = mnist.train.num_examples
接下来定义存储TFRecord文件的地址,同时创建一个writer来写TFRecord文件。
# 存储TFRecord文件的地址 filename = 'record/output.tfrecords' # 创建一个writer来写TFRecord文件 writer = tf.python_io.TFRecordWriter(filename)
此时便可以按照一定的格式写入了,此时需要对每一张图片进行循环并写入,在tf.train.Features中利用features字典定义了数据保存的方式。以image_raw为例,其经过函数_float_feature处理后,存储到tfrecords文件的’image/encoded’位置上。
# 将每张图片都转为一个Example,并写入
for i in range(num_examples):
image_raw = images[i] # 读取每一幅图像
image_string = images[i].tostring()
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/class/label': _int64_feature(np.argmax(labels[i])),
'image/encoded': _float_feature(image_raw),
'image/encoded_tostring': _bytes_feature(image_string)
}
)
)
print(i,"/",num_examples)
writer.write(example.SerializeToString()) # 将Example写入TFRecord文件
在最终存入前,数据还需要经过处理,处理方式如下:
# 生成整数的属性
def _int64_feature(value):
if not isinstance(value,list) and not isinstance(value,np.ndarray):
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
# 生成浮点数的属性
def _float_feature(value):
if not isinstance(value,list) and not isinstance(value,np.ndarray):
value = [value]
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
# 生成字符串型的属性
def _bytes_feature(value):
if not isinstance(value,list) and not isinstance(value,np.ndarray):
value = [value]
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
tfrecords的读取首先要创建一个reader来读取TFRecord文件中的Example。
# 创建一个reader来读取TFRecord文件中的Example reader = tf.TFRecordReader()
再创建一个队列来维护输入文件列表。
# 创建一个队列来维护输入文件列表 filename_queue = tf.train.string_input_producer(['record/output.tfrecords'])
利用reader读取输入文件列表队列,并用parse_single_example将读入的Example解析成tensor
# 从文件中读出一个Example
_, serialized_example = reader.read(filename_queue)
# 用parse_single_example将读入的Example解析成tensor
features = tf.parse_single_example(
serialized_example,
features={
'image/class/label': tf.FixedLenFeature([], tf.int64),
'image/encoded': tf.FixedLenFeature([784], tf.float32, default_value=tf.zeros([784], dtype=tf.float32)),
'image/encoded_tostring': tf.FixedLenFeature([], tf.string)
}
)
此时我们得到了一个features,实际上它是一个类似于字典的东西,我们额可以通过字典的方式读取它内部的内容,而字典的索引就是我们再写入tfrecord文件时所用的feature。
# 将字符串解析成图像对应的像素数组 labels = tf.cast(features['image/class/label'], tf.int32) images = tf.cast(features['image/encoded'], tf.float32) images_tostrings = tf.decode_raw(features['image/encoded_tostring'], tf.float32)
最后利用一个循环输出:
# 每次运行读取一个Example。当所有样例读取完之后,在此样例中程序会重头读取
for i in range(5):
label, image = sess.run([labels, images])
images_tostring = sess.run(images_tostrings)
print(np.shape(image))
print(np.shape(images_tostring))
print(label)
print("#########################")
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 生成整数的属性
def _int64_feature(value):
if not isinstance(value,list) and not isinstance(value,np.ndarray):
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
# 生成浮点数的属性
def _float_feature(value):
if not isinstance(value,list) and not isinstance(value,np.ndarray):
value = [value]
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
# 生成字符串型的属性
def _bytes_feature(value):
if not isinstance(value,list) and not isinstance(value,np.ndarray):
value = [value]
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
# 读取MNIST数据集
mnist = input_data.read_data_sets('./MNIST_data', dtype=tf.float32, one_hot=True)
# 获得image,shape为(55000,784)
images = mnist.train.images
# 获得label,shape为(55000,10)
labels = mnist.train.labels
# 获得一共具有多少张图片
num_examples = mnist.train.num_examples
# 存储TFRecord文件的地址
filename = 'record/Mnist_Out.tfrecords'
# 创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
# 将每张图片都转为一个Example,并写入
for i in range(num_examples):
image_raw = images[i] # 读取每一幅图像
image_string = images[i].tostring()
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/class/label': _int64_feature(np.argmax(labels[i])),
'image/encoded': _float_feature(image_raw),
'image/encoded_tostring': _bytes_feature(image_string)
}
)
)
print(i,"/",num_examples)
writer.write(example.SerializeToString()) # 将Example写入TFRecord文件
print('data processing success')
writer.close()
运行结果为:
……
54993 / 55000
54994 / 55000
54995 / 55000
54996 / 55000
54997 / 55000
54998 / 55000
54999 / 55000
data processing success
import tensorflow as tf
import numpy as np
# 创建一个reader来读取TFRecord文件中的Example
reader = tf.TFRecordReader()
# 创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer(['record/Mnist_Out.tfrecords'])
# 从文件中读出一个Example
_, serialized_example = reader.read(filename_queue)
# 用parse_single_example将读入的Example解析成tensor
features = tf.parse_single_example(
serialized_example,
features={
'image/class/label': tf.FixedLenFeature([], tf.int64),
'image/encoded': tf.FixedLenFeature([784], tf.float32, default_value=tf.zeros([784], dtype=tf.float32)),
'image/encoded_tostring': tf.FixedLenFeature([], tf.string)
}
)
# 将字符串解析成图像对应的像素数组
labels = tf.cast(features['image/class/label'], tf.int32)
images = tf.cast(features['image/encoded'], tf.float32)
images_tostrings = tf.decode_raw(features['image/encoded_tostring'], tf.float32)
sess = tf.Session()
# 启动多线程处理输入数据
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 每次运行读取一个Example。当所有样例读取完之后,在此样例中程序会重头读取
for i in range(5):
label, image = sess.run([labels, images])
images_tostring = sess.run(images_tostrings)
print(np.shape(image))
print(np.shape(images_tostring))
print(label)
print("#########################")
运行结果为:
#########################
(784,)
(784,)
7
#########################
#########################
(784,)
(784,)
4
#########################
#########################
(784,)
(784,)
1
#########################
#########################
(784,)
(784,)
1
#########################
#########################
(784,)
(784,)
9
#########################
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
本篇文章给大家带来了关于Python的相关知识,详细介绍了Python实现提取四种不同文本特征的方法,有字典文本特征提取、英文文本特征提取、中文文本特征提取和TF-IDF 文本特征提取,感兴趣的可以了解一下。
内容介绍chr()函数与ord()函数解析chr()函数ord()函数应用:凯撒密码的加密和解码ord()函数与chr()函数的区别chr()函数与ord()函数解析chr()函数用一个范围在ran
大部分程序和语言中的随机数,其实都只是伪随机。是由可确定的函数(常用线性同余),通过一个种子(常用时钟)产生的。直观来想,计算机就是一种确定的、可预测的的设备:一行行的代码是固定的,一步步的算法是固定的,一个个与非门是固定的。
Matplotlib是Python中最受欢迎的数据可视化软件包之一,它是 Python常用的2D绘图库,同时它也提供了一部分3D绘图接口。本文将详细介绍Matplotlib的绘图方式,需要的可以参考一下
Python中字符串反转常用的五种方法:使用字符串切片、使用递归、使用列表reverse()方法、使用栈和使用for循环。
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008