BN和dropout是什么,使用有哪些不同
Admin 2022-09-05 群英技术资讯 878 次浏览
本篇内容介绍了“BN和dropout是什么,使用有哪些不同”的有关知识,在实际项目的操作过程或是学习过程中,不少人都会遇到这样的问题,接下来就让小编带大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!BN在训练时是在每个batch上计算均值和方差来进行归一化,每个batch的样本量都不大,所以每次计算出来的均值和方差就存在差异。预测时一般传入一个样本,所以不存在归一化,其次哪怕是预测一个batch,但batch计算出来的均值和方差是偏离总体样本的,所以通常是通过滑动平均结合训练时所有batch的均值和方差来得到一个总体均值和方差。
以tensorflow代码实现为例:
def bn_layer(self, inputs, training, name='bn', moving_decay=0.9, eps=1e-5):
# 获取输入维度并判断是否匹配卷积层(4)或者全连接层(2)
shape = inputs.shape
param_shape = shape[-1]
with tf.variable_scope(name):
# 声明BN中唯一需要学习的两个参数,y=gamma*x+beta
gamma = tf.get_variable('gamma', param_shape, initializer=tf.constant_initializer(1))
beta = tf.get_variable('beat', param_shape, initializer=tf.constant_initializer(0))
# 计算当前整个batch的均值与方差
axes = list(range(len(shape)-1))
batch_mean, batch_var = tf.nn.moments(inputs , axes, name='moments')
# 采用滑动平均更新均值与方差
ema = tf.train.ExponentialMovingAverage(moving_decay, name="ema")
def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean, batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
# 训练时,更新均值与方差,测试时使用之前最后一次保存的均值与方差
mean, var = tf.cond(tf.equal(training,True), mean_var_with_update,
lambda:(ema.average(batch_mean), ema.average(batch_var)))
# 最后执行batch normalization
return tf.nn.batch_normalization(inputs ,mean, var, beta, gamma, eps)
training参数可以通过tf.placeholder传入,这样就可以控制训练和预测时training的值。
self.training = tf.placeholder(tf.bool, name="training")
Dropout在训练时会随机丢弃一些神经元,这样会导致输出的结果变小。而预测时往往关闭dropout,保证预测结果的一致性(不关闭dropout可能同一个输入会得到不同的输出,不过输出会服从某一分布。另外有些情况下可以不关闭dropout,比如文本生成下,不关闭会增大输出的多样性)。
为了对齐Dropout训练和预测的结果,通常有两种做法,假设dropout rate = 0.2。一种是训练时不做处理,预测时输出乘以(1 - dropout rate)。另一种是训练时留下的神经元除以(1 - dropout rate),预测时不做处理。以tensorflow为例。
x = tf.nn.dropout(x, self.keep_prob)
self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")
tf.nn.dropout就是采用了第二种做法,训练时除以(1 - dropout rate),源码如下:
binary_tensor = math_ops.floor(random_tensor) ret = math_ops.div(x, keep_prob) * binary_tensor if not context.executing_eagerly(): ret.set_shape(x.get_shape()) return ret
binary_tensor就是一个mask tensor,即里面的值由0或1组成。keep_prob = 1 - dropout rate。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要介绍Matplotlib绘制子图的方式,常用的方式有通过plt的subplot、通过figure的add_subplot和通过plt的subplots,下面我们就来看看怎样绘制子图吧,感兴趣的朋友可以参考。
今天带大家复习python基础知识,此文章将要介绍如何组织文件,既拷贝,移动等,文中有非常详细的代码示例,对正在学习python的小伙伴们很有帮助,需要的朋友可以参考下
这篇文章主要介绍了pytorch部署到jupyter中,在这里需要注意我再输入的时候出现了一些无法定位的提示,但是我的电脑没有影响使用jupyter,还是可以使用jupyter并且可以import torch,本文给大家讲解的非常详细,需要的朋友参考下吧
json库是处理JSON格式的Python标准库,json库主要包括两类函数,操作函数和解析函数,下面这篇文章主要给大家介绍了关于python标准库模块之json库的基础用法,需要的朋友可以参考下
这篇文章主要介绍了Python 文件操作方法总结,文章基于python的基础展开Python 文件操作方法,具有一定的参考价值,需要的小伙伴可以参考一下
成为群英会员,开启智能安全云计算之旅
立即注册关注或联系群英网络
7x24小时售前:400-678-4567
7x24小时售后:0668-2555666
24小时QQ客服
群英微信公众号
CNNIC域名投诉举报处理平台
服务电话:010-58813000
服务邮箱:service@cnnic.cn
投诉与建议:0668-2555555
Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 ICP核准(ICP备案)粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008