pytorch模型保存与加载的实例中有哪些问题
Admin 2022-11-22 群英技术资讯 1212 次浏览
在这篇文章中,我们来学习一下“pytorch模型保存与加载的实例中有哪些问题”的相关知识,下文有详细的讲解,易于大家学习和理解,有需要的朋友可以借鉴参考,下面就请大家跟着小编的思路一起来学习一下吧。torch.save(model,path)
torch.load(path)
登录后复制
torch.save(model.state_dict(),path)
model_state_dic = torch.load(path)
model.load_state_dic(model_state_dic)
登录后复制
模型保存的时候会把模型结构定义文件路径记录下来,加载的时候就会根据路径解析它然后装载参数;当把模型定义文件路径修改以后,使用torch.load(path)就会报错。



把model文件夹修改为models后,再加载就会报错。
import torch
from model.TextRNN import TextRNN
load_model = torch.load('experiment_model_save/textRNN.bin')
print('load_model',load_model)登录后复制
这种保存完整模型结构和参数的方式,一定不要改动模型定义文件路径。
在多卡机器上有多张显卡0号开始,现在模型在n>=1上的显卡训练保存后,拷贝在单卡机器上加载
import torch
from model.TextRNN import TextRNN
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin')
print('load_model',load_model)登录后复制

会出现cuda device不匹配的问题——你保存的模代码段 小部件型是使用的cuda1,那么采用torch.load()打开的时候,会默认的去寻找cuda1,然后把模型加载到该设备上。这个时候可以直接使用map_location来解决,把模型加载到CPU上即可。
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))登录后复制
当用多GPU同时训练模型之后,不管是采用模型结构和参数一起保存还是单独保存模型参数,然后在单卡下加载都会出现问题
a、模型结构和参数一起保然后在加载

torch.distributed.init_process_group(backend='nccl')
登录后复制
模型训练的时候采用上述多进程的方式,所以你在加载的时候也要声明,不然就会报错。
b、单独保存模型参数
model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load('train_model/clip/experiment.pt')
model.load_state_dict(state_dict)登录后复制
同样会出现问题,不过这里出现的问题是参数字典的key和模型定义的key不一样

原因是多GPU训练下,使用分布式训练的时候会给模型进行一个包装,代码如下:
model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin')
print(model)
model.cuda(args.local_rank)
。。。。。。
model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True)
print('model',model)登录后复制
包装前的模型结构:

包装后的模型

在外层多了DistributedDataParallel以及module,所以才会导致在单卡环境下加载模型权重的时候出现权重的keys不一致。
if gpu_count > 1:
torch.save(model.module.state_dict(),save_path)
else:
torch.save(model.state_dict(),save_path)
model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load(save_path)
model.load_state_dict(state_dict)登录后复制
这样就是比较好的范式,加载不会出错。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
大家好,本篇文章主要讲的是python绘制超炫酷动态Julia集示例,感兴趣的痛学赶快来看一看吧,对你有帮助的话记得收藏一下,方便下次浏览
大家好,本篇文章主要讲的是python二分法查找函数底值,感兴趣的同学赶快来看一看吧,对你有用的话记得收藏一下,方便下次浏览
最近处理一些规格不一的照片,需要修改成指定尺寸便于打印,下面这篇文章主要给大家介绍了关于python批处理将图片进行放大的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
在使用 Python 处理字符串的时候,有时候会需要分割字符。 分隔符比如下划线 “_”,比如 “.”之类的。一个分隔符 比如对于文件名 20191022_log.zip,我们想要获取前面的日期。 如果日期格式固定,对于这样的字符串我们当然可以使用索引进行切割。 当...
协程:英文名Coroutine,是单线程下的并发,又称微线程,纤程。协程是一种用户态的轻量级线程,即协程是由用户程序自己控制调度的。对比操作系统控制线程的切换,用户在单线程内控制协程的切换。协程自己本身无法实现并发(甚至性能会降低),协程+IO切换性能提高。
成为群英会员,开启智能安全云计算之旅
立即注册关注或联系群英网络
7x24小时售前:400-678-4567
7x24小时售后:0668-2555666
24小时QQ客服
群英微信公众号
CNNIC域名投诉举报处理平台
服务电话:010-58813000
服务邮箱:service@cnnic.cn
投诉与建议:0668-2555555
Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 ICP核准(ICP备案)粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008