如何理解python神经网络tf.train.batch函数的使用
Admin 2022-09-16 群英技术资讯 564 次浏览
tf.train.batch( tensors, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None )
其中:
1、tensors:利用slice_input_producer获得的数据组合。
2、batch_size:设置每次从队列中获取出队数据的数量。
3、num_threads:用来控制线程的数量,如果其值不唯一,由于线程执行的特性,数据获取可能变成乱序。
4、capacity:一个整数,用来设置队列中元素的最大数量
5、allow_samller_final_batch:当其为True时,如果队列中的样本数量小于batch_size,出队的数量会以最终遗留下来的样本进行出队;当其为False时,小于batch_size的样本不会做出队处理。
6、name:名字
import pandas as pd import numpy as np import tensorflow as tf # 生成数据 def generate_data(): num = 18 label = np.arange(num) return label # 获取数据 def get_batch_data(): label = generate_data() input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2) label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=True) return label_batch # 数据组 label = get_batch_data() sess = tf.Session() # 初始化变量 sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) # 初始化batch训练的参数 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess,coord) try: while not coord.should_stop(): # 自动获取下一组数据 l = sess.run(label) print(l) except tf.errors.OutOfRangeError: print('Done training') finally: coord.request_stop() coord.join(threads) sess.close()
运行结果为:
[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 0 1]
[2 3 4 5 6]
[ 7 8 9 10 11]
[12 13 14 15 16]
[17]
Done training
相比allow_samller_final_batch=True,输出结果少了[17]
import pandas as pd import numpy as np import tensorflow as tf # 生成数据 def generate_data(): num = 18 label = np.arange(num) return label # 获取数据 def get_batch_data(): label = generate_data() input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2) label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False) return label_batch # 数据组 label = get_batch_data() sess = tf.Session() # 初始化变量 sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) # 初始化batch训练的参数 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess,coord) try: while not coord.should_stop(): # 自动获取下一组数据 l = sess.run(label) print(l) except tf.errors.OutOfRangeError: print('Done training') finally: coord.request_stop() coord.join(threads) sess.close()
运行结果为:
[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 0 1]
[2 3 4 5 6]
[ 7 8 9 10 11]
[12 13 14 15 16]
Done training
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
pygame是为开发2D游戏而设计的Python跨平台模块,开发人员利用pygame模块中定义的接口,可以方便快捷地实现诸如图形用户界面创建、图形和图像的绘制、用户键盘和鼠标操作的监听以及播放音频等游戏中常用的功能。
这篇文章主要为大家介绍了基于Python如何实现评论区抽奖的功能,文章的示例代码讲解详细,对我们学习Python有一定帮助,需要的朋友可以学习一下
这篇文章主要和大家分享一个有意思的模型:RealBasicVSR。本文将利用这个模型制作一个图像超分处理工具,感兴趣的小伙伴可以跟随小编一起学习一下
这篇文章主要介绍了python中的h5py开源库的使用,本文只是简单的对h5py库的基本创建文件,数据集和读取数据的方式进行介绍,需要的朋友可以参考下
#!/use/bin/envpython#-*-conding:utf-8-*-#author:shanshan"""写代码1,实现用户输入用户名和密码,当用户名为seven且密码为123时,显示登陆成功,否则登陆失败!2,实现用户输入用户名和密码,当用户名为seven且密码为123时,显示登陆成功,否则登陆失败,失败时允许重复输入三次3,实现用户输入用户名和密码,当用户
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008