图像分割实战-系列教程9:U2NET显著性检测实战1

在这里插入图片描述

🍁🍁🍁图像分割实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

U2NET显著性检测实战1
U2NET显著性检测实战2
U2NET显著性检测实战3

1、任务概述

所谓(Salient Objection Detection,SOD )显著性检测,就是把所有的全景都识别出来,把背景全部去除掉

输入是一张正常的图像:
  • 输出的是一种素描画像的图像
  • 扣掉前景的图像
  • 只有目标前景的图像

可以用来扣绿幕,扣前景等

2、任务数据集介绍

左边就是训练数据,有多种类的图像

右边就是标签,将前景扣出来的图像

3、U2Net概述

  1. 图a,一般的网络都是卷积Conv、批归一化BatchNormalization、ReLU三连,即CBR
  2. 图b,后来有了ResNet,可以保证不断加深网络不会出现性能下降的情况
  3. 图c、图d,unetunet++网络网络主要是在特征拼接上做出了改进提取出更好的特征,获得更好的效果
  4. 图e,即本文的网络U2Net,即采用了ResNet的残差连接,又采用了特征拼接,有更好的效果

Unet的U没变,但是为什么叫U2Net呢,这里的2是平方的意思,实际上就是在一些小细节方面又做了一次Unet结构,每一个小模块不在像之前的Unet使用VGG来做backbones,而是每一个backbones都使用了Unet,即U2Net,666哈哈哈

4、U2Net网络结构源码解读

在这里插入图片描述

4.1 构造函数

class U2NET(nn.Module):

    def __init__(self,in_ch=3,out_ch=1):
        super(U2NET,self).__init__()

        self.stage1 = RSU7(in_ch,32,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage2 = RSU6(64,32,128)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage3 = RSU5(128,64,256)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage4 = RSU4(256,128,512)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage5 = RSU4F(512,256,512)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage6 = RSU4F(512,256,512)

        # decoder
        self.stage5d = RSU4F(1024,256,512)
        self.stage4d = RSU4(1024,128,256)
        self.stage3d = RSU5(512,64,128)
        self.stage2d = RSU6(256,32,64)
        self.stage1d = RSU7(128,16,64)

        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(512,out_ch,3,padding=1)

        self.outconv = nn.Conv2d(6,out_ch,1)
  1. 首先解释一下RSU7、RSU6这些,表示的是残差Unet,这是一个残差连接的Unet模块
  2. 编码器部分,定义了5个的(残差Unet+Maxpooling),再加一次没有Maxpooling的残差Unet
  3. 每次残差Unet都会传入,输入通道,中间通道,输出通道3个参数
  4. 解码器部分,定义了对应的5个残差Unet
  5. 最后定义了6个二维卷积,和一个最后输出的二维卷积

4.2 前向传播

4.2.1 encoder

    def forward(self,x):
        hx = x
        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)
        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)

编码器部分,按照顺序走完5个残差Unet+Maxpooling,一个残差Unet+上采样

4.2.1 decoder

        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)
        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)
        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
        
        d1 = self.side1(hx1d)
        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,d1)
        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,d1)
        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,d1)
        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,d1)
        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,d1)
        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
        return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)

解码器部分,按顺序走完4次(对应Unet部分的特征拼接+残差Unet+上采样),加上一次没有上采样的(对应Unet部分的特征拼接+残差Unet)
然后在对应部分输出的地方经过一次二维卷积和一次上采样操作
最后每个部分对应的输出都会经过sigmoid激活函数

每个位置的对应shape值输出:
输入数据:torch.Size([6, 3, 288, 288])
stage 1
torch.Size([6, 64, 288, 288])
torch.Size([6, 64, 144, 144])

stage 2
torch.Size([6, 128, 144, 144])
torch.Size([6, 128, 72, 72])

stage 3
torch.Size([6, 256, 72, 72])
torch.Size([6, 256, 36, 36])

stage 4
torch.Size([6, 512, 36, 36])
torch.Size([6, 512, 18, 18])

stage 5
torch.Size([6, 512, 18, 18])
torch.Size([6, 512, 9, 9])

stage 6
torch.Size([6, 512, 9, 9])
torch.Size([6, 512, 18, 18])

stage5d
torch.Size([6, 512, 18, 18])
torch.Size([6, 512, 36, 36])

stage4d
torch.Size([6, 256, 36, 36])
torch.Size([6, 256, 72, 72])

stage3d
torch.Size([6, 128, 72, 72])
torch.Size([6, 128, 144, 144])

stage2d
torch.Size([6, 64, 144, 144])
torch.Size([6, 64, 288, 288])

stage1d
torch.Size([6, 64, 288, 288])

side1
torch.Size([6, 1, 288, 288])

side2
torch.Size([6, 1, 144, 144])
torch.Size([6, 1, 288, 288])

side3
torch.Size([6, 1, 72, 72])
torch.Size([6, 1, 288, 288])

side4
torch.Size([6, 1, 36, 36])
torch.Size([6, 1, 288, 288])

side5
torch.Size([6, 1, 18, 18])
torch.Size([6, 1, 288, 288])

side6
torch.Size([6, 1, 9, 9])
torch.Size([6, 1, 288, 288])

输出
torch.Size([6, 1, 288, 288])

U2NET显著性检测实战1
U2NET显著性检测实战2
U2NET显著性检测实战3


http://www.niftyadmin.cn/n/5306850.html

相关文章

神经网络-卷积层

卷积 输入通道数, 输出通道数,核大小 参数具体含义 直观理解各个参数的网站(gif) https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md大概长这样,cyan是青色的意思 channel数(终于理解论文里图片放好多层的原因…

Windows 截图工具①FastStone Capture ②PixPin

​ 今天说说Windows好用的载图软件: 2023年12月21日,更新新增 PixPin 2款超级好用的截图软件,不止是载图、编辑、录屏、Gif等好用的功能,本人一直也在用。 ①FastStone Capture ②PixPin 软件介绍:FastStone Fast…

Debezium发布历史50

原文地址: https://debezium.io/blog/2019/02/25/debezium-0-9-2-final-released/ 欢迎关注留言,我是收集整理小能手,工具翻译,仅供参考,笔芯笔芯. Debezium 0.9.2.Final 发布 二月 25, 2019 作者: Gunna…

计算机创新协会冬令营——暴力枚举题目06

我给大家第一阶段的最后一道题就到这里了,下次得过段时间了。所以这道题简单一点。但是足够经典 下述题目描述和示例均来自力扣:两数之和 题目描述 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target …

Django发送QQ邮件

创建一个表单,供用户填写他们的姓名和电子邮件、电子邮件收件人和可选的注释 创建blog/forms.py from django import formsclass EmailPostForm(forms.Form):name forms.CharField(max_length25)email forms.EmailField()to forms.EmailField()comments forms.…

MySQL-数据库概述

数据库相关概念: 数据库(DateBase)简称DB,就是一个存储数据的仓库,数据有组织的进行存储。 数据库分为关系型数据库简称RDBMS和非关系型数据库 关系型数据库简称RDBMS:建立在关系模型的基础上,由多张相互连接的二维表组成的数据库.简单来说…

【c语言】指针小结

一、指针是什么? 可以通过运算符&来取得变量实际保存的 起始地址 。 (这个地址是虚拟地址,并不是真正物理内存上的地址。) 数据类型 *标识符 &变量; int *pa &a; int *pa NULL; (NULL表示地址为0的内存空间&a…

Spring Boot依赖版本声明

链接 官网 Spring Boot文档官网:​​​​​​https://docs.spring.io/spring-boot/docs/https://docs.spring.io/spring-boot/docs/ Spring Boot 2.0.7.RELEASE Spring Boot 2.0.7.RELEASE reference相关:https://docs.spring.io/spring-boot/docs/2.…