Pytorch搭建SRGAN以生成高分辨率图片怎样做
Admin 2022-09-15 群英技术资讯 818 次浏览
SRGAN出自论文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。
如果将SRGAN看作一个黑匣子,其主要的功能就是输入一张低分辨率图片,生成高分辨率图片。
该文章提到,普通的超分辨率模型训练网络时只用到了均方差作为损失函数,虽然能够获得很高的峰值信噪比,但是恢复出来的图像通常会丢失高频细节。
SRGAN利用感知损失(perceptual loss)和对抗损失(adversarial loss)来提升恢复出的图片的真实感。
生成网络的构成如上图所示,生成网络的作用是输入一张低分辨率图片,生成高分辨率图片。:
SRGAN的生成网络由三个部分组成。
1、低分辨率图像进入后会经过一个卷积+RELU函数。
2、然后经过B个残差网络结构,每个残差结构都包含两个卷积+标准化+RELU,还有一个残差边。
3、然后进入上采样部分,在经过两次上采样后,原图的高宽变为原来的4倍,实现分辨率的提升。
前两个部分用于特征提取,第三部分用于提高分辨率。
import math import torch from torch import nn class ResidualBlock(nn.Module): def __init__(self, channels): super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(channels) self.prelu = nn.PReLU(channels) self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(channels) def forward(self, x): short_cut = x x = self.conv1(x) x = self.bn1(x) x = self.prelu(x) x = self.conv2(x) x = self.bn2(x) return x + short_cut class UpsampleBLock(nn.Module): def __init__(self, in_channels, up_scale): super(UpsampleBLock, self).__init__() self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1) self.pixel_shuffle = nn.PixelShuffle(up_scale) self.prelu = nn.PReLU(in_channels) def forward(self, x): x = self.conv(x) x = self.pixel_shuffle(x) x = self.prelu(x) return x class Generator(nn.Module): def __init__(self, scale_factor, num_residual=16): upsample_block_num = int(math.log(scale_factor, 2)) super(Generator, self).__init__() self.block_in = nn.Sequential( nn.Conv2d(3, 64, kernel_size=9, padding=4), nn.PReLU(64) ) self.blocks = [] for _ in range(num_residual): self.blocks.append(ResidualBlock(64)) self.blocks = nn.Sequential(*self.blocks) self.block_out = nn.Sequential( nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64) ) self.upsample = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)] self.upsample.append(nn.Conv2d(64, 3, kernel_size=9, padding=4)) self.upsample = nn.Sequential(*self.upsample) def forward(self, x): x = self.block_in(x) short_cut = x x = self.blocks(x) x = self.block_out(x) upsample = self.upsample(x + short_cut) return torch.tanh(upsample)
判别网络的构成如上图所示:
SRGAN的判别网络由不断重复的 卷积+LeakyRELU和标准化 组成。
对于判断网络来讲,它的目的是判断输入图片的真假,它的输入是图片,输出是判断结果。
判断结果处于0-1之间,利用接近1代表判断为真图片,接近0代表判断为假图片。
判断网络的构建和普通卷积网络差距不大,都是不断的卷积对图片进行下采用,在多次卷积后,最终接一次全连接判断结果。
实现代码如下:
class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.net = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2), nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2), nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Conv2d(256, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2), nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2), nn.AdaptiveAvgPool2d(1), nn.Conv2d(512, 1024, kernel_size=1), nn.LeakyReLU(0.2), nn.Conv2d(1024, 1, kernel_size=1) ) def forward(self, x): batch_size = x.size(0) return torch.sigmoid(self.net(x).view(batch_size))
SRGAN的训练可以分为生成器训练和判别器训练:
每一个step中一般先训练判别器,然后训练生成器。
在训练判别器的时候我们希望判别器可以判断输入图片的真伪,因此我们的输入就是真图片、假图片和它们对应的标签。
因此判别器的训练步骤如下:
1、随机选取batch_size个真实高分辨率图片。
2、利用resize后的低分辨率图片,传入到Generator中生成batch_size个虚假高分辨率图片。
3、真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练。
在训练生成器的时候我们希望生成器可以生成极为真实的假图片。因此我们在训练生成器需要知道判别器认为什么图片是真图片。
因此生成器的训练步骤如下:
1、将低分辨率图像传入生成模型,得到虚假高分辨率图像,将虚假高分辨率图像获得判别结果与1进行对比得到loss。(与1对比的意思是,让生成器根据判别器判别的结果进行训练)。
2、将真实高分辨率图像和虚假高分辨率图像传入VGG网络,获得两个图像的特征,通过这两个图像的特征进行比较获得loss
SRGAN的库整体结构如下:
在训练前需要准备好数据集,数据集保存在datasets文件夹里面。
打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。
此时生成根目录下面的train_lines.txt。
在完成数据集处理后,运行train.py即可开始训练。
训练过程中,可在results文件夹内查看训练效果:
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
今天给大家带来的是关于Python的相关知识,文章围绕着如何使用Python脚本实现自动登录校园网展开,文中有非常详细的介绍及代码示例,需要的朋友可以参考下
使用Python我们可以轻松地将数据转换成不同的类型,下面这篇文章主要给大家介绍了关于Python容器类型转换的3种方法,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
这篇文章主要介绍了Python小技巧练习分享,文章基于python的相关内容展开详细的python小技巧内容,具有一定的参考价值,需要的小伙伴可以参考一下
这篇文章主要为大家介绍了如何通过Python实现一个非常精简的图像化的PDF区域选择提取工具,文中示例代码讲解详细,感兴趣的小伙伴可以学习一下
这篇文章主要介绍了Pytorch中使用TensorBoard详情,TensorBoard的前段数据显示和后端数据记录是异步I/O的,即后端程序将数据写入到一个文件中,而前端程序读取文件中的数据来进行显示
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008