LSTM是什么,在Keras中如何实现LSTM
Admin 2022-09-17 群英技术资讯 909 次浏览
这篇文章主要介绍“LSTM是什么,在Keras中如何实现LSTM”的相关知识,下面会通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“LSTM是什么,在Keras中如何实现LSTM”文章能帮助大家解决问题。
我们可以看出,在n时刻,LSTM的输入有三个:
LSTM的输出有两个:
LSTM用两个门来控制单元状态cn的内容:
LSTM用一个门来控制当前输出值hn的内容:
输出门(output gate),它利用当前时刻单元状态cn对hn的输出进行控制。


遗忘门这里需要结合ht-1和Xt来决定上一时刻的单元状态cn-1有多少保留到当前时刻;
由图我们可以得到,我们在这一环节需要计一个参数ft。



输入门这里需要结合ht-1和Xt来决定当前时刻网络的输入c’n有多少保存到单元状态cn中。
由图我们可以得到,我们在这一环节需要计算两个参数,分别是it。

和C’t

里面需要训练的参数分别是Wi、bi、WC和bC。
在定义LSTM的时候我们会使用到一个参数叫做units,其实就是神经元的个数,也就是LSTM的输出——ht的维度。
所以:


输出门利用当前时刻单元状态cn对hn的输出进行控制;
由图我们可以得到,我们在这一环节需要计一个参数ot。

里面需要训练的参数分别是Wo和bo。在定义LSTM的时候我们会使用到一个参数叫做units,其实就是神经元的个数,也就是LSTM的输出——ht的维度。所以:

所以所有的门总参数量为:

LSTM一般需要输入两个参数。
一个是unit、一个是input_shape。
LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))
unit用于指定神经元的数量。
input_shape用于指定输入的shape,分别指定TIME_STEPS和INPUT_SIZE。
import numpy as np
from keras.models import Sequential
from keras.layers import Input,Activation,Dense
from keras.models import Model
from keras.datasets import mnist
from keras.layers.recurrent import LSTM
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
inputs = Input(shape=[TIME_STEPS,INPUT_SIZE])
x = LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))(inputs)
x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)
model = Model(inputs,x)
adam = Adam(LR)
model.summary()
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
for i in range(50000):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
print("accuracy:",accuracy)
实现效果:
10000/10000 [==============================] - 3s 340us/step accuracy: 0.14040000014007092 10000/10000 [==============================] - 3s 310us/step accuracy: 0.6507000041007995 10000/10000 [==============================] - 3s 320us/step accuracy: 0.7740999992191792 10000/10000 [==============================] - 3s 305us/step accuracy: 0.8516999959945679 10000/10000 [==============================] - 3s 322us/step accuracy: 0.8669999945163727 10000/10000 [==============================] - 3s 324us/step accuracy: 0.889699995815754 10000/10000 [==============================] - 3s 307us/step
关于“LSTM是什么,在Keras中如何实现LSTM”的内容今天就到这,感谢各位的阅读,大家可以动手实际看看,对大家加深理解更有帮助哦。如果想了解更多相关内容的文章,关注我们,群英网络小编每天都会为大家更新不同的知识。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要介绍了python 基础绘图之关于随时间序列变动的图的画法,首先大家要明白画图需要考虑的问题,如何在图中适当的显示轴标签的样式和数量,详细代码跟随小编一起看看吧
这篇文章主要给大家分享Python字符串对齐的方法,包括ljust()、rjust() 和 center() 这三种方法都可以来文本对齐,感兴趣的朋友可以参考一下,希望大家阅读完这篇文章能有所收获,下面我们一起来学习一下吧。
这篇文章主要介绍了Python连接数据库使用matplotlib画柱形图,文章通过实例展开对主题的相关介绍。具有一定的知识参考价值性,感兴趣的小伙伴可以参考一下
这篇文章主要为大家详细介绍了PyQt中实现自定义工具提示ToolTip的方法详解,文中的示例代码讲解详细,对我们学习有一定帮助,需要的可以参考一下
这篇文章主要介绍了了Python如何实现将多张图片合成一张马赛克图片。文中的示例代码讲解详细,对我们学习Python有一定的帮助,感兴趣的可以学习一下
成为群英会员,开启智能安全云计算之旅
立即注册关注或联系群英网络
7x24小时售前:400-678-4567
7x24小时售后:0668-2555666
24小时QQ客服
群英微信公众号
CNNIC域名投诉举报处理平台
服务电话:010-58813000
服务邮箱:service@cnnic.cn
投诉与建议:0668-2555555
Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 ICP核准(ICP备案)粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008