GRU是什么,在Keras中如何实现GRU
Admin 2022-09-17 群英技术资讯 1029 次浏览
很多朋友都对“GRU是什么,在Keras中如何实现GRU”的内容比较感兴趣,对此小编整理了相关的知识分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获,那么感兴趣的朋友就继续往下看吧!GRU是LSTM的一个变种。
传承了LSTM的门结构,但是将LSTM的三个门转化成两个门,分别是更新门和重置门。
下图是每个GRU单元的结构。

在n时刻,每个GRU单元的输入有两个:
输出有一个:
当前时刻GRU输出值ht;
GRU含有两个门结构,分别是:
更新门zt和重置门rt:
更新门用于控制前一时刻的状态信息被代入到当前状态的程度,更新门的值越大说明前一时刻的状态信息带入越少,这一时刻的状态信息带入越多。
重置门用于控制忽略前一时刻的状态信息的程度,重置门的值越小说明忽略得越多。

更新门在图中的标号为zt,需要结合ht-1和Xt来决定上一时刻的输出ht-1有多少得到保留,更新门的值越大说明前一时刻的状态信息保留越少,这一时刻的状态信息保留越多。
结合公式我们可以知道:

zt由ht-1和Xt来决定。

当更新门zt的值较大的时候,上一时刻的输出ht-1保留较少,而这一时刻的状态信息保留较多。


重置门在图中的标号为rt,需要结合ht-1和Xt来控制忽略前一时刻的状态信息的程度,重置门的值越小说明忽略得越多。
结合公式我们可以知道:

rt由ht-1和Xt来决定。

当重置门rt的值较小的时候,上一时刻的输出ht-1保留较少,说明忽略得越多。

所以所有的门总参数量为:

GRU一般需要输入两个参数。
一个是unit、一个是input_shape。
LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))
unit用于指定神经元的数量。
input_shape用于指定输入的shape,分别指定TIME_STEPS和INPUT_SIZE。
import numpy as np
from keras.models import Sequential
from keras.layers import Input,Activation,Dense
from keras.models import Model
from keras.datasets import mnist
from keras.layers.recurrent import GRU
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
inputs = Input(shape=[TIME_STEPS,INPUT_SIZE])
x = GRU(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))(inputs)
x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)
model = Model(inputs,x)
adam = Adam(LR)
model.summary()
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
for i in range(50000):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
print("accuracy:",accuracy)
实现效果:
10000/10000 [==============================] - 2s 231us/step accuracy: 0.16749999986961484 10000/10000 [==============================] - 2s 206us/step accuracy: 0.6134000015258789 10000/10000 [==============================] - 2s 214us/step accuracy: 0.7058000019192696 10000/10000 [==============================] - 2s 209us/step accuracy: 0.797899999320507
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要为大家介绍了python中的变量,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
Prometheus是一套开源监控系统和告警为一体,由go语言(golang)开发,是监控+报警+时间序列数据库的组合。本文将介绍Python如何调用Prometheus实现数据的监控与计算,需要的可以参考一下
这篇文章主要为大家介绍了pytorch深度神经网络入门准备自己的图片数据示例过程,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
冒泡排序(Bubble Sort)是一种简单的排序算法。本文将详细为大家讲讲Python实现冒泡排序算法的方法,感兴趣的小伙伴可以跟随小编一起学习一下
这篇文章主要介绍了OpenCV实战之OpenCV中的颜色空间,解计算机视觉中常用的色彩空间,并将其用于基于颜色分割。我们还将用C ++和Python共享演示代码,下文详细内容需要的小伙伴可以参考一下
成为群英会员,开启智能安全云计算之旅
立即注册关注或联系群英网络
7x24小时售前:400-678-4567
7x24小时售后:0668-2555666
24小时QQ客服
群英微信公众号
CNNIC域名投诉举报处理平台
服务电话:010-58813000
服务邮箱:service@cnnic.cn
投诉与建议:0668-2555555
Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008