pytorch大数据加载memory不足够的问题如何解决
Admin 2022-07-23 群英技术资讯 1654 次浏览
今天这篇我们来学习和了解“pytorch大数据加载memory不足够的问题如何解决”,下文的讲解详细,步骤过程清晰,对大家进一步学习和理解“pytorch大数据加载memory不足够的问题如何解决”有一定的帮助。有这方面学习需要的朋友就继续往下看吧!最近用pytorch做实验时,遇到加载大量数据的问题。实验数据大小在400Gb,而本身机器的memory只有256Gb,显然无法将数据一次全部load到memory。
首先自定义一个MyDataset继承torch.utils.data.Dataset,然后将MyDataset的对象feed in torch.utils.data.DataLoader()即可。
MyDataset在__init__中声明一个文件对象,然后在__getitem__中缓慢读取数据,这样就不会一次把所有数据加载到内存中了。训练数据存放在train.txt中,每一行是一条数据记录。
import torch.utils.data as Data from tqdm import tqdm class MyDataset(Data.Dataset): def __init__(self,filepath): number = 0 with open(filepath,"r") as f: # 获得训练数据的总行数 for _ in tqdm(f,desc="load training dataset"): number+=1 self.number = number self.fopen = open(filepath,'r') def __len__(self): return self.number def __getitem__(self,index): line = self.fopen.__next__() # 自定义transform()对训练数据进行预处理 data = transform(line) return data train_dataset = MyDataset(filepath = "train.txt") training_data = Data.DataLoader(dataset=train_dataset, batch_size=32,num_workers=1)
1、num_workers只能设置为1。因为MyDataset初始化时只有一个文件对象,在dataloader时num_workers=1只用一个线程去操作文件对象读取数据。如果num_workers>1, 会出错,多个线程同时操作同一个文件对象,得到的数据并不是你想要的。
2、每一个epoch结束以后,需要重新声明train_dataset和training_data。因为一个epoch结束以后,文件对象已经指向文件末尾,下一个epoch取数据时,什么也得不到。
3、因为这里__getitem__()只是顺序的从文件中取出一行,而与index无关,那么在DataLoader时,即使参数shuffle指定为True,得到的数据依然是顺序的,即该方法无法shuffle数据。
补充:Pytorch加载自己的数据集(使用DataLoader读取Dataset)
很多时候我们需要加载自己的数据集,这时候我们需要使用Dataset和DataLoader
Dataset:是被封装进DataLoader里,实现该方法封装自己的数据和标签。
DataLoader:被封装入DataLoaderIter里,实现该方法达到数据的划分。
阅读源码后,我们可以指导,继承该方法必须实现两个方法:
_getitem_()
_len_()
因此,在实现过程中我们测试如下:
import torch
import numpy as np
# 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法
class GetLoader(torch.utils.data.Dataset):
# 初始化函数,得到数据
def __init__(self, data_root, data_label):
self.data = data_root
self.label = data_label
# index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
def __getitem__(self, index):
data = self.data[index]
labels = self.label[index]
return data, labels
# 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
def __len__(self):
return len(self.data)
# 随机生成数据,大小为10 * 20列
source_data = np.random.rand(10, 20)
# 随机生成标签,大小为10 * 1列
source_label = np.random.randint(0,2,(10, 1))
# 通过GetLoader将数据进行加载,返回Dataset对象,包含data和labels
torch_data = GetLoader(source_data, source_label)
提供对Dataset的操作,操作如下:
torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
参数含义如下:
dataset: 加载torch.utils.data.Dataset对象数据
batch_size: 每个batch的大小
shuffle:是否对数据进行打乱
drop_last:是否对无法整除的最后一个datasize进行丢弃
num_workers:表示加载的时候子进程数
因此,在实现过程中我们测试如下(紧跟上述用例):
from torch.utils.data import DataLoader # 读取数据 datas = DataLoader(torch_data, batch_size=6, shuffle=True, drop_last=False, num_workers=2)
此时,我们的数据已经加载完毕了,只需要在训练过程中使用即可。
我们可以通过迭代器(enumerate)进行输出数据,测试如下:
for i, data in enumerate(datas):
# i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
print("第 {} 个Batch \n{}".format(i, data))
输出结果如下图:

结果说明:由于数据的是10个,batchsize大小为6,且drop_last=False,因此第一个大小为6,第二个为4。每一个batch中包含data和对应的labels。
当我们想取出data和对应的labels时候,只需要用下表就可以啦,测试如下:
# 表示输出数据 print(data[0]) # 表示输出标签 print(data[1])
结果如图:

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要给大家分享python做三维图可视化的内容,这是学习tensorflow框架中遇到的知识,下文会使用到的定义函数选用的是将x和y封装,方便tensorflow求导,下面我们就来具体一些实现代码以及要注意的问题。
列表 List列表是有序的列表可以包含任意对象通过索引访问列表元素列表嵌套列表可变元组 Tuple定义和使用元组元素对比列表的优点元组分配、打包和解包List 与 Tuple 的
这篇文章主要介绍了python优雅实现代码与敏感信息分离的方法,在flask中,python-dotenv 可以无缝接入项目中,只要你的项目中存在 .env 或者 .flaskenv 文件,他就会提示你是否安装 python-dotenv,需要的朋友可以参考下
property是一个类,可以把一个方法当做属性进行使用,这样做可以简化代码使用。实际上就是装饰类中属性的gettersetter方法,使得属性可以通过对象.属性的方式获取或设置 使用property的两种方式装饰器方式类属性方式2.装饰器方式@property修饰获取的方法getter,方法名必须和属性名一样@age.setter修饰设置值的方法sett
Python基础-函数(三)<1> 对返回的数据直接拆包。<2> 交换2个变量的值
成为群英会员,开启智能安全云计算之旅
立即注册关注或联系群英网络
7x24小时售前:400-678-4567
7x24小时售后:0668-2555666
24小时QQ客服
群英微信公众号
CNNIC域名投诉举报处理平台
服务电话:010-58813000
服务邮箱:service@cnnic.cn
投诉与建议:0668-2555555
Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 ICP核准(ICP备案)粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008