pytorch如何实现图片数据准备,怎么做
Admin 2022-08-05 群英技术资讯 1015 次浏览
很多朋友都对“pytorch如何实现图片数据准备,怎么做”的内容比较感兴趣,对此小编整理了相关的知识分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获,那么感兴趣的朋友就继续往下看吧!图片数据一般有两种情况:
1、所有图片放在一个文件夹内,另外有一个txt文件显示标签。
2、不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。
针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理。下面分别进行说明:
这里以mnist数据集的10000个test为例, 我先把test集的10000个图片保存出来,并生着对应的txt标签文件。
先在当前目录创建一个空文件夹mnist_test, 用于保存10000张图片,接着运行代码:
import torch
import torchvision
import matplotlib.pyplot as plt
from skimage import io
mnist_test= torchvision.datasets.MNIST(
'./mnist', train=False, download=True
)
print('test set:', len(mnist_test))
f=open('mnist_test.txt','w')
for i,(img,label) in enumerate(mnist_test):
img_path="./mnist_test/"+str(i)+".jpg"
io.imsave(img_path,img)
f.write(img_path+' '+str(label)+'\n')
f.close()
经过上面的操作,10000张图片就保存在mnist_test文件夹里了,并在当前目录下生成了一个mnist_test.txt的文件,大致如下:

前期工作就装备好了,接着就进入正题了:
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image
def default_loader(path):
return Image.open(path).convert('RGB')
class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(self.imgs)
train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader))
def show_batch(imgs):
grid = utils.make_grid(imgs)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader')
for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(),batch_y.size())
show_batch(batch_x)
plt.axis('off')
plt.show()
自定义了一个MyDataset, 继承自torch.utils.data.Dataset。然后利用torch.utils.data.DataLoader将整个数据集分成多个批次。
同样先准备数据,这里以flowers数据集为例
提取 链接: https://pan.baidu.com/s/1dcAsOOZpUfWNYR77JGXPHA?pwd=mwg6
花总共有五类,分别放在5个文件夹下。大致如下图:

我的路径是d:/flowers/.
数据准备好了,就开始准备Dataset吧,这里直接调用torchvision里面的ImageFolder
import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt
img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
transform=transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
)
print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader))
def show_batch(imgs):
grid = utils.make_grid(imgs,nrow=5)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader')
for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(), batch_y.size())
show_batch(batch_x)
plt.axis('off')
plt.show()
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
在处理数据的时候,很多时候会遇到批量替换的情况,如果一个一个去修改效率过低,也容易出错,replace()是很好的方法,下面这篇文章主要给大家介绍了关于Python pandas.replace用法的相关资料,需要的朋友可以参考下
在Python里,我们有时候会做需要多行输出的程序。例如:1、点餐系统 不停地问:你要点什么食物?2、文本编辑 不停地输入文字(仅限IDLE等Python自带编辑器 )我们Python中有一种输入语句 : input。但是,它只能单行输入所以呢,我们就要通过Python的其他语句来实现多行输入
哈希表或称为散列表,是一种常见的、使用频率非常高的数据存储方案。本文将站在开发者的角度,带着大家一起探究哈希的世界,感兴趣的小伙伴可以跟随小编一起学习一下
这篇文章主要为大家介绍了Python内建类型float源码学习,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
这篇文章主要为大家介绍了用Python和OpenCV实现的六种常见图像特效:图像融合、灰度处理、马赛克效果、浮雕效果、毛玻璃效果和颜色反转,需要的可以参考一下
成为群英会员,开启智能安全云计算之旅
立即注册关注或联系群英网络
7x24小时售前:400-678-4567
7x24小时售后:0668-2555666
24小时QQ客服
群英微信公众号
CNNIC域名投诉举报处理平台
服务电话:010-58813000
服务邮箱:service@cnnic.cn
投诉与建议:0668-2555555
Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 ICP核准(ICP备案)粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008