🍁🍁🍁图像分割实战-系列教程 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
U2NET显著性检测实战1
U2NET显著性检测实战2
U2NET显著性检测实战3
1、任务概述
所谓(Salient Objection Detection,SOD )显著性检测,就是把所有的全景都识别出来,把背景全部去除掉
- 输出的是一种素描画像的图像
- 扣掉前景的图像
- 只有目标前景的图像
可以用来扣绿幕,扣前景等
2、任务数据集介绍
左边就是训练数据,有多种类的图像
右边就是标签,将前景扣出来的图像
3、U2Net概述
- 图a,一般的网络都是卷积Conv、批归一化BatchNormalization、ReLU三连,即CBR
- 图b,后来有了ResNet,可以保证不断加深网络不会出现性能下降的情况
- 图c、图d,unet、unet++网络网络主要是在特征拼接上做出了改进提取出更好的特征,获得更好的效果
- 图e,即本文的网络U2Net,即采用了ResNet的残差连接,又采用了特征拼接,有更好的效果
Unet的U没变,但是为什么叫U2Net呢,这里的2是平方的意思,实际上就是在一些小细节方面又做了一次Unet结构,每一个小模块不在像之前的Unet使用VGG来做backbones,而是每一个backbones都使用了Unet,即U2Net,666哈哈哈
4、U2Net网络结构源码解读
4.1 构造函数
class U2NET(nn.Module):
def __init__(self,in_ch=3,out_ch=1):
super(U2NET,self).__init__()
self.stage1 = RSU7(in_ch,32,64)
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage2 = RSU6(64,32,128)
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage3 = RSU5(128,64,256)
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage4 = RSU4(256,128,512)
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage5 = RSU4F(512,256,512)
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage6 = RSU4F(512,256,512)
# decoder
self.stage5d = RSU4F(1024,256,512)
self.stage4d = RSU4(1024,128,256)
self.stage3d = RSU5(512,64,128)
self.stage2d = RSU6(256,32,64)
self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
self.outconv = nn.Conv2d(6,out_ch,1)
- 首先解释一下RSU7、RSU6这些,表示的是残差Unet,这是一个残差连接的Unet模块
- 编码器部分,定义了5个的(残差Unet+Maxpooling),再加一次没有Maxpooling的残差Unet
- 每次残差Unet都会传入,输入通道,中间通道,输出通道3个参数
- 解码器部分,定义了对应的5个残差Unet
- 最后定义了6个二维卷积,和一个最后输出的二维卷积
4.2 前向传播
4.2.1 encoder
def forward(self,x):
hx = x
#stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
#stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
#stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
#stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
#stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
#stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6,hx5)
编码器部分,按照顺序走完5个残差Unet+Maxpooling,一个残差Unet+上采样
4.2.1 decoder
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2,d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3,d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4,d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5,d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6,d1)
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
解码器部分,按顺序走完4次(对应Unet部分的特征拼接+残差Unet+上采样),加上一次没有上采样的(对应Unet部分的特征拼接+残差Unet)
然后在对应部分输出的地方经过一次二维卷积和一次上采样操作
最后每个部分对应的输出都会经过sigmoid激活函数
每个位置的对应shape值输出:
输入数据:torch.Size([6, 3, 288, 288])
stage 1
torch.Size([6, 64, 288, 288])
torch.Size([6, 64, 144, 144])
stage 2
torch.Size([6, 128, 144, 144])
torch.Size([6, 128, 72, 72])
stage 3
torch.Size([6, 256, 72, 72])
torch.Size([6, 256, 36, 36])
stage 4
torch.Size([6, 512, 36, 36])
torch.Size([6, 512, 18, 18])
stage 5
torch.Size([6, 512, 18, 18])
torch.Size([6, 512, 9, 9])
stage 6
torch.Size([6, 512, 9, 9])
torch.Size([6, 512, 18, 18])
stage5d
torch.Size([6, 512, 18, 18])
torch.Size([6, 512, 36, 36])
stage4d
torch.Size([6, 256, 36, 36])
torch.Size([6, 256, 72, 72])
stage3d
torch.Size([6, 128, 72, 72])
torch.Size([6, 128, 144, 144])
stage2d
torch.Size([6, 64, 144, 144])
torch.Size([6, 64, 288, 288])
stage1d
torch.Size([6, 64, 288, 288])
side1
torch.Size([6, 1, 288, 288])
side2
torch.Size([6, 1, 144, 144])
torch.Size([6, 1, 288, 288])
side3
torch.Size([6, 1, 72, 72])
torch.Size([6, 1, 288, 288])
side4
torch.Size([6, 1, 36, 36])
torch.Size([6, 1, 288, 288])
side5
torch.Size([6, 1, 18, 18])
torch.Size([6, 1, 288, 288])
side6
torch.Size([6, 1, 9, 9])
torch.Size([6, 1, 288, 288])
输出
torch.Size([6, 1, 288, 288])
U2NET显著性检测实战1
U2NET显著性检测实战2
U2NET显著性检测实战3