RNN的tf.nn.dynamic_rnn定义和用法是什么
Admin 2022-09-16 群英技术资讯 757 次浏览
在实际应用中,我们有时候会遇到“RNN的tf.nn.dynamic_rnn定义和用法是什么”这样的问题,我们该怎样来处理呢?下文给大家介绍了解决方法,希望这篇“RNN的tf.nn.dynamic_rnn定义和用法是什么”文章能帮助大家解决问题。已经完成了RNN网络的构建,但是我们对于RNN网络还有许多疑问,特别是tf.nn.dynamic_rnn函数,其具体的应用方式我们并不熟悉,查询了一下资料,我心里的想法是这样的。
tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
其返回值为outputs,states。
outputs:RNN的最后一层的输出,是一个tensor。如果为time_major== False,则它的shape为[batch_size,max_time,cell.output_size]。如果为time_major== True,则它的shape为[max_time,batch_size,cell.output_size]。
states:是每一层的最后一个step的输出,是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下states的形状为 [batch_size, cell.output_size],但当输入的cell为BasicLSTMCell时,states的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。
我们首先使用单层的RNN进行实验。
使用的代码为:
import tensorflow as tf
import numpy as np
n_steps = 2 #两个step
n_inputs = 3 #每个input是三维
n_nerve = 4 #神经元个数
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_nerve)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
init = tf.global_variables_initializer()
X_batch = np.array([[[0, 1, 2], [1, 2, 3]],
[[3, 4, 5], [4, 5, 6]],
[[5, 6, 7], [6, 7, 8]],
[[7, 8, 9], [8, 9, 10]]])
with tf.Session() as sess:
sess.run(init)
outputs_val, states_val = sess.run([outputs, states], feed_dict={X: X_batch})
print("outputs:", outputs_val)
print("states:", states_val)
输出的log为:
outputs: [[[0.92146313 0.6069534 0.24989243 0.9305415 ] [0.9234855 0.8470011 0.7865616 0.99935764]] [[0.9772771 0.9713368 0.99483156 0.9999987 ] [0.9753329 0.99538314 0.9988139 1. ]] [[0.9901842 0.99558043 0.9998626 1. ] [0.989398 0.9992842 0.9999691 1. ]] [[0.99577546 0.9993256 0.99999636 1. ] [0.9954579 0.9998903 0.99999917 1. ]]] states: [[0.9234855 0.8470011 0.7865616 0.99935764] [0.9753329 0.99538314 0.9988139 1. ] [0.989398 0.9992842 0.9999691 1. ] [0.9954579 0.9998903 0.99999917 1. ]]
在time_major = False的时候:
接下来我们使用两层的RNN进行实验。
使用的代码为:
import tensorflow as tf
import numpy as np
n_steps = 2 #两个step
n_inputs = 3 #每个input是三维
n_nerve = 4 #神经元个数
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
#定义多层
layers = [tf.nn.rnn_cell.BasicRNNCell(num_units=n_nerve) for i in range(2)]
multi_layer_cell = tf.contrib.rnn.MultiRNNCell(layers)
outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32)
init = tf.global_variables_initializer()
X_batch = np.array([[[0, 1, 2], [1, 2, 3]],
[[3, 4, 5], [4, 5, 6]],
[[5, 6, 7], [6, 7, 8]],
[[7, 8, 9], [8, 9, 10]]])
with tf.Session() as sess:
sess.run(init)
outputs_val, states_val = sess.run([outputs, states], feed_dict={X: X_batch})
print("outputs:", outputs_val)
print("states:", states_val)
输出的log为:
outputs: [[[-0.577939 -0.3657474 -0.21074213 0.8188577 ]
[-0.67090076 -0.47001836 -0.40080917 0.6026697 ]]
[[-0.72777444 -0.36500326 -0.7526911 0.86113644]
[-0.7928404 -0.6413429 -0.61007065 0.787065 ]]
[[-0.7537433 -0.35850585 -0.83090436 0.8573037 ]
[-0.82016116 -0.6559162 -0.7360482 0.7915131 ]]
[[-0.7597004 -0.35760364 -0.8450942 0.8567379 ]
[-0.8276395 -0.6573326 -0.7727142 0.7895221 ]]]
states: (array([[-0.71645427, -0.0585744 , 0.95318353, 0.8424729 ],
[-0.99845 , -0.5044571 , 0.9955299 , 0.9750488 ],
[-0.99992913, -0.8408632 , 0.99885863, 0.9932366 ],
[-0.99999577, -0.9672 , 0.9996866 , 0.99814796]],
dtype=float32),
array([[-0.67090076, -0.47001836, -0.40080917, 0.6026697 ],
[-0.7928404 , -0.6413429 , -0.61007065, 0.787065 ],
[-0.82016116, -0.6559162 , -0.7360482 , 0.7915131 ],
[-0.8276395 , -0.6573326 , -0.7727142 , 0.7895221 ]],
dtype=float32))
可以看出来outputs对应的是RNN的最后一层的输出,states对应的是每一层的最后一个step的输出,在完成了两层的定义后,outputs的shape并没有变化,而states的内容多了一层,分别对应RNN的两层输出。
state中最后一层输出对应着outputs最后一步的输出。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
Cookie用于服务器实现会话,用户登录及相关功能时进行状态管理。要在用户浏览器上安装cookie,HTTP服务器向HTTP响应添加类似以下内容的HTTP
for循环用于迭代序列(即列表、元组、字典、集合或字符串)。for 语句的写法如从对象开始按顺序给变量赋值,元素个数重复这个过程。对象可以是列表(数组)、元组、字符串等。本文将详细讲解Python中for定义迭代方法详解,需要的可以了解一下
在写Python的时候经常会遇到时间格式的问题,首先就是最近用到的时间戳(timestamp)和时间字符串之间的转换。所谓时间戳,就是从 1970年1
这篇文章主要介绍了Python文件及目录处理的方法,Python可以用于处理文本文件和二进制文件,比如创建文件、读写文件等操作。本文介绍Python处理目录以及文件的相关资料,需要的朋友可以参考一下
这篇文章主要介绍了Django表单外键选项初始化的问题及解决方法,需本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,要的朋友可以参考下
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008