Pytorch框架怎样实现病虫害图像分类,代码是什么
Admin 2022-07-05 群英技术资讯 689 次浏览
PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。
2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能:
1、具有强大的GPU加速的张量计算(如NumPy)。
2、包含自动求导系统的深度神经网络。
两者之间区别很多,在本篇博客中只简单描述一部分。以图片的形式展现。
前者为机器学习的过程。
后者为深度学习的过程。
本次实验使用的是coco数据集中的植物病虫害数据集。分为训练文件Traindata和测试文件TestData.,
TrainData有9种分类,每一种分类有100张图片。
TestData有9中分类,每一种分类有10张图片。
在我下一篇博客中将数据集开源。
下面是我的数据集截图:
import torch from torch.utils.data import Dataset, DataLoader import numpy as np import matplotlib import os import cv2 from PIL import Image import torchvision.transforms as transforms import torch.optim as optim from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F from Test.CNN import Net import json from Test.train_data import Mydataset,pad_image
# 构建神经网络 class Net(nn.Module):#定义网络模块 def __init__(self): super(Net, self).__init__() # 卷积,该图片有3层,6个特征,长宽均为5*5的像素点,每隔1步跳一下 self.conv1 = nn.Conv2d(3, 6, 5) #//(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) self.pool = nn.MaxPool2d(2, 2)#最大池化 #//(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) self.conv2 = nn.Conv2d(6, 16, 5)#卷积 #//(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) self.fc1 = nn.Linear(16*77*77, 120)#全连接层,图片的维度为16, #(fc1): Linear(in_features=94864, out_features=120, bias=True) self.fc2 = nn.Linear(120, 84)#全连接层,输入120个特征输出84个特征 self.fc3 = nn.Linear(84, 7)#全连接层,输入84个特征输出7个特征 def forward(self, x): print("x.shape1: ", x.shape) x = self.pool(F.relu(self.conv1(x))) print("x.shape2: ", x.shape) x = self.pool(F.relu(self.conv2(x))) print("x.shape3: ", x.shape) x = x.view(-1, 16*77*77) print("x.shape4: ", x.shape) x = F.relu(self.fc1(x)) print("x.shape5: ", x.shape) x = F.relu(self.fc2(x)) print("x.shape6: ", x.shape) x = self.fc3(x) print("x.shape7: ", x.shape) return x
img_path = "TestData/test_data/1/Apple2 (1).jpg" #使用相对路径 image = Image.open(img_path).convert('RGB') image_pad = pad_image(image, (320, 320)) input = transform(image_pad).to(device).unsqueeze(0) output = F.softmax(net(input), 1) _, predicted = torch.max(output, 1) score = float(output[0][predicted]*100) print(class_map[predicted], " ", str(score)+" %") plt.imshow(image_pad) # 显示图片
这次搭建的网络是基于深度学习框架Lenet,并自己做了一些修改完成。最终的训练的结果LOSS接近0,ACC接近100%。但是一般的识别率不会达到这么高,该模型可能会过拟合。可采取剪枝等操作减小过拟合。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要为大家介绍了R语言条形图及分布密度图代码总结,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
搜索路径是由一系列目录名组成的,Python解释器就依次从这些目录中去寻找所引入的模块,下面这篇文章主要给大家介绍了关于python修改包导入时搜索路径的相关资料,需要的朋友可以参考下
对于Python语言来说,比较传统的数据可视化模块是Matplotlib,但它存在不够美观、静态性、不易分享等缺点,限制了Python在数据可视化方面的发展。为了解决这个问题,新型的动态可视化开源模块Plotly应运而生。本文将为大家详细介绍Plotly的用法,需要的可以参考一下
本文主要介绍了python实现自动抢课脚本的示例代码,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
通过NETCONF,网管能够用可视化的界面统一管理网络中的设备,并且安全性高、可靠性强、扩展性强。如下图所示,网管与网络中的所有交换机之间建立NETCONF会话,用户即可在网管提供的可视化界面上对网络中的所有交换机进行统一的管理,提高网络运维效率。
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008