Pytorch识别LeNet模型怎样实现的
Admin 2022-09-06 群英技术资讯 1264 次浏览
这篇文章给大家分享的是Pytorch识别LeNet模型怎样实现的。小编觉得挺实用的,因此分享给大家做个参考,文中的介绍得很详细,而要易于理解和学习,有需要的朋友可以参考,接下来就跟随小编一起了解看看吧。
LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import numpy as np
import tqdm as tqdm
class LeNet(nn.Module):
def __init__(self) -> None:
super().__init__()
self.sequential = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2,stride=2),
nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2,stride=2),
nn.Flatten(),
nn.Linear(16*25,120),nn.Sigmoid(),
nn.Linear(120,84),nn.Sigmoid(),
nn.Linear(84,10))
def forward(self,x):
return self.sequential(x)
class MLP(nn.Module):
def __init__(self) -> None:
super().__init__()
self.sequential = nn.Sequential(nn.Flatten(),
nn.Linear(28*28,120),nn.Sigmoid(),
nn.Linear(120,84),nn.Sigmoid(),
nn.Linear(84,10))
def forward(self,x):
return self.sequential(x)
epochs = 15
batch = 32
lr=0.9
loss = nn.CrossEntropyLoss()
model = LeNet()
optimizer = torch.optim.SGD(model.parameters(),lr)
device = torch.device('cuda')
root = r"./"
trans_compose = transforms.Compose([transforms.ToTensor(),
])
train_data = torchvision.datasets.MNIST(root,train=True,transform=trans_compose,download=True)
test_data = torchvision.datasets.MNIST(root,train=False,transform=trans_compose,download=True)
train_loader = DataLoader(train_data,batch_size=batch,shuffle=True)
test_loader = DataLoader(test_data,batch_size=batch,shuffle=False)
model.to(device)
loss.to(device)
# model.apply(init_weights)
for epoch in range(epochs):
train_loss = 0
test_loss = 0
correct_train = 0
correct_test = 0
for index,(x,y) in enumerate(train_loader):
x = x.to(device)
y = y.to(device)
predict = model(x)
L = loss(predict,y)
optimizer.zero_grad()
L.backward()
optimizer.step()
train_loss = train_loss + L
correct_train += (predict.argmax(dim=1)==y).sum()
acc_train = correct_train/(batch*len(train_loader))
with torch.no_grad():
for index,(x,y) in enumerate(test_loader):
[x,y] = [x.to(device),y.to(device)]
predict = model(x)
L1 = loss(predict,y)
test_loss = test_loss + L1
correct_test += (predict.argmax(dim=1)==y).sum()
acc_test = correct_test/(batch*len(test_loader))
print(f'epoch:{epoch},train_loss:{train_loss/batch},test_loss:{test_loss/batch},acc_train:{acc_train},acc_test:{acc_test}')
epoch:12,train_loss:2.235553741455078,test_loss:0.3947642743587494,acc_train:0.9879833459854126,acc_test:0.9851238131523132
epoch:13,train_loss:2.028963804244995,test_loss:0.3220392167568207,acc_train:0.9891499876976013,acc_test:0.9875199794769287
epoch:14,train_loss:1.8020273447036743,test_loss:0.34837451577186584,acc_train:0.9901833534240723,acc_test:0.98702073097229
找了一张图片,将其分割成只含一个数字的图片进行测试

images_np = cv2.imread("/content/R-C.png",cv2.IMREAD_GRAYSCALE)
h,w = images_np.shape
images_np = np.array(255*torch.ones(h,w))-images_np#图片反色
images = Image.fromarray(images_np)
plt.figure(1)
plt.imshow(images)
test_images = []
for i in range(10):
for j in range(16):
test_images.append(images_np[h//10*i:h//10+h//10*i,w//16*j:w//16*j+w//16])
sample = test_images[77]
sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
predict = model(sample_tensor)
output = predict.argmax()
print(output)
plt.figure(2)
plt.imshow(np.array(sample_tensor.squeeze().to('cpu')))

此时预测结果为4,预测正确。从这段代码中可以看到有一个反色的步骤,若不反色,结果会受到影响,如下图所示,预测为0,错误。
模型用于输入的图片是单通道的黑白图片,这里由于可视化出现了黄色,但实际上是黑白色,反色操作说明了数据的预处理十分的重要,很多数据如果是不清理过是无法直接用于推理的。

将所有用来泛化性测试的图片进行准确率测试:
correct = 0
i = 0
cnt = 1
for sample in test_images:
sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
predict = model(sample_tensor)
output = predict.argmax()
if(output==i):
correct+=1
if(cnt%16==0):
i+=1
cnt+=1
acc_g = correct/len(test_images)
print(f'acc_g:{acc_g}')
如果不反色,acc_g=0.15
acc_g:0.50625
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
时间处理是我们日常开发中最最常见的需求,例如:获取当前datetime、获取当天date、获取明天 前N天、获取当天开始和结束时间(00:00:00 23:
这篇文章主要为大家介绍了Python字典的运算 ,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
这篇文章主要介绍了pytorch 实现变分自动编码器的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
Pampy是Python的一个模式匹配类库,一个只有150行的类库,该库优雅、高效值得广大Python的码农加入自己基本开发栈中。本文就来讲讲Pampy的用法,需要的可以参考一下
如果要考察某公司的牛奶产品质量,可以从100袋牛奶中抽取30袋,在随机数表中选中一数,并用向上、下、左、右不同的读法组成30个数,并按牛奶的标号进行检测,虽然麻烦,但很常用。在日常生活中,随机数起着很大的作用,所以很多人会专门去寻找随机数生成器。
成为群英会员,开启智能安全云计算之旅
立即注册关注或联系群英网络
7x24小时售前:400-678-4567
7x24小时售后:0668-2555666
24小时QQ客服
群英微信公众号
CNNIC域名投诉举报处理平台
服务电话:010-58813000
服务邮箱:service@cnnic.cn
投诉与建议:0668-2555555
Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 ICP核准(ICP备案)粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008