U2Net使用方法和实现多类别语义分割模型改造

作者的碎碎念:U2Net是用来实现SOD的语义分割,本篇论文会介绍算法内容、主要代码、使用方法,以及如何将二分类语义分割修改为多类别语义模型。如果只想知道怎么训练自己的数据集,或者如何修改网络,可以通过目录进行跳转。
欢迎点赞、评论或收藏❤️


文章目录

  • (一)相关链接
  • (二)算法内容
    • 1. 摘要
    • 2. 介绍
    • 3. 网络架构
    • 4. loss函数
    • 5. 作者实验结果
  • (三)如何训练自己的数据
    • 1. 标注
    • 2. mask图像
    • 3. 训练数据集格式
    • 4. 配置文件修改
    • 5. 训练命令
    • 6. 测试命令
  • (四)多类别语义分割
    • 1. 实现思路
    • 2. 修改方法
    • 4. 测试
    • 5. 训练测试效果

(一)相关链接

  1. 论文名称
    《U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection》
  2. github链接
    https://github.com/xuebinqin/U-2-Net
  3. paper
    https://arxiv.org/pdf/2005.09007.pdf

(二)算法内容

1. 摘要

  U²-Net是显著物体检测(salient object detection,简写SOD)的一个网络,并且现在已经是Python的抠图工具Rembg的基础算法

  • 什么是SOD?
      SOD是模拟人类视觉感知系统来定位场景中最吸引人的目标,例如人像
  • 算法优点总结
    (1)能获取到更多的上下文信息(RSU块,ReSidual U-blocks)
    (2)增加网络深度但没有增加计算量。并且可以从0开始训练,不用从分类预训练网络中再训练
  • 模型大小
      U2-Net (176.3 MB, 30 FPS on GTX 1080Ti GPU)
      U2-Net†(4.7 MB, 40 FPS)

2. 介绍

  • 现有的SOD网络存在什么问题?
    (1)现有的模式基本都是使用已有的backbone,例如AlexNet、VGG、ResNet。这些基础的网络都是为分类任务而设计的,提取的特征更多是语义特征,而不是定位特征和全局对比的信息。
    (2)耗用大量的资源
    (3)牺牲高分辨率的特征映射来实现更深层次的体系结构
  • U2Net的目标是网络更深、使用的资源和计算量更少、能够保持高分辨率的特征图。怎么做呢?
    (1)用两级的内嵌U型结构,不使用分类的backbone
    (2)新型的网络结构更深、能获取高分辨率图像、不增加内存和计算量

3. 网络架构

  • 卷积结构和RSU结构比对
    在这里插入图片描述

(1)( a ) Plain convolution blockPLN
     ( b ) Residual-like block RES
     ( c ) Dense-like block DSE
     ( d ) Inception-like block INC
     ( e ) Our residual U-blockRSU
(2)(a)到( c )是典型的卷积结构,用了1x1和3x3的卷积,感受野太小,只能用来获取local feature
(3)(d)用了空洞卷积增大了感受野,但是需要大的内存和计算资源
(4)RSU-L模块,(L代表层数),Cin:输入通道,Cout:输出通道,M:RSU内部通道

  • 开销比对
    在这里插入图片描述
    RSU的开销(overhead)不大,因为都是下采样,DSE和INC比较大
  • 残差结构比对
    在这里插入图片描述
    (1)残差块:H(x) = F2(F1(x))+x,H(x)是x的映射,F1和F2是卷积操作【对应两个weight layer】
    (2)RSU:HRSU (x) = U(F1(x))+F1(x),RSU和残差不同的地方,是将卷积替换成像Unet的U型结构U-block,原来的输入x替换成F1(x)【weight layer之后】
  • 网络架构
    在这里插入图片描述

  U-Net-like这种结构本来就有,只不过是级联起来,Uxn Net,而作者提出来的是 Un Net,用内嵌(nested)结构而不是级联结构
(1)结构特点:11个stage,每个stage都是RSU结构
   🔸 a six stages encoder
   🔸a five stages decoder
   🔸a saliency map fusion module attached with the decoder stages and the last encoder stage
(2)编码器:
   🔹En_1、En_2、En_3、En_4(即前四个)用到的RSU层数是 RSU-7、 RSU-6、 RSU-5、 RSU-4,层数越多,尺度信息越丰富
   🔹En-5和En-6用了RSU-4F,用了空洞卷积,保证了输入输出是相同的分辨率
(3)解码器:
   De-5也是用了RSU-4F,和En-5、En-6类似
(4)融合模块(saliency map fusion module):
   编码器和解码器的输出,经过3x3卷积和sigmoid,upsample,输出了6个概率热力图:S_side(6)、S_side(5)、S_side(4)、S_side(3)、S_side(2)、S_side(1) ,用1x1卷积进行融合,产生了S_fuse

4. loss函数

在这里插入图片描述
✅总Loss等于所有loss之和,包括S_side(6)、S_side(5)、S_side(4)、S_side(3)、S_side(2)、S_side(1),和融合的S_fuse
在这里插入图片描述
✅每一层的S_side(x)的loss,使用了二分类交叉熵损失函数

5. 作者实验结果

在这里插入图片描述
Red, Green, and Blue indicate the best, second best and third best performance
在这里插入图片描述

(三)如何训练自己的数据

1. 标注

用labelme标注图片,生成json文件
在这里插入图片描述

2. mask图像

将json文件转换为mask图片,背景黑色,物体白色,下面是转换代码:

python">import cv2
import json
import numpy as np
import os
import sys


def func(file:str) -> np.ndarray:
    with open(file, mode='r', encoding="utf-8") as f:
        configs = json.load(f)
    shapes = configs["shapes"]

    png = np.zeros((configs["imageHeight"], configs["imageWidth"], 3), np.uint8)

    for shape in shapes:
        cv2.fillPoly(png, [np.array(shape["points"], np.int32)], (255,255,255))

    return png


if  __name__ == "__main__":

    if len(sys.argv) != 3:
        raise ValueError("json文件或目录 输出路径")

    if os.path.isdir(sys.argv[1]):
        for file in os.listdir(sys.argv[1]):
            cv2.imwrite(os.path.join(sys.argv[2], os.path.splitext(file)[0]+".png" ), func(os.path.join(sys.argv[1], file)))
    else:
        cv2.imwrite(os.path.join(sys.argv[2], os.path.splitext(os.path.basename(sys.argv[1]))[0]+".png"), func(sys.argv[1]))

在这里插入图片描述

转换的mask图像

3. 训练数据集格式

1️⃣在工程目录创建目录:train_data/DUTS/DUTS-TR/DUTS-TR/
2️⃣在第一步骤创建的目录上,创建目录im_aug,将原图放在这
3️⃣在第一步骤创建的目录上,创建目录gt_aug,将转换的mask图放在这

4. 配置文件修改

  打开u2net_train.py,一般可以设置这几项:
  model_name = ‘u2net’ # 用u2net或者u2netp模型进行训练
  epoch_num = 100000 # 训练轮次
  batch_size_train = 12 # batchsize
  save_frq = 2000 # 每2000个iter保存一个模型

5. 训练命令

python u2net_train.py

6. 测试命令

python u2net_test.py

(四)多类别语义分割

  作者提供的代码只实现了二分类的语义分割,U2Net是否可以用来做多类别的语义分割?答案是可以了,下面提供了将二分类语义分割转换为多类别语义分割的方法

1. 实现思路

🔺项目背景:图片有两个类别,分别是螺丝钉和位移线
🔺类别:两个类别+背景,num_class = 3,如果有更多类别,则是n+1类,1是背景
🔺mask图片:二分类时,填充的是0和255;多分类,不同类别可以填充为0(背景)、1(螺丝钉)、2(位移线),所以最多只能分出0~255个类别。查看3个类别的mask,因为像素值只有0、1、2,肉眼看基本是一张黑色图像
🔺模型输出:三个类别,输出三个通道,如[3, 320, 320],每一个通道代表一个类别

2. 修改方法

(1)获取多类别训练mask脚本

python">import cv2
import json
import numpy as np
import os
import sys


def func(file):
    with open(file, mode='r', encoding="utf-8") as f:
        configs = json.load(f)
    shapes = configs["shapes"]

    png = np.zeros((configs["imageHeight"], configs["imageWidth"], 3), np.uint8)

    for shape in shapes:
        label = shape['label']
        if label == 'lm':
            cv2.fillPoly(png, [np.array(shape["points"], np.int32)], (1,1,1))
        else:
            cv2.fillPoly(png, [np.array(shape["points"], np.int32)], (2,2,2))

    return png


if  __name__ == "__main__":
    json_dir = "./train_data/labels_json"
    
    save_dir = './train_data/masks'


    for file in os.listdir(json_dir):
        print(file)
        png = func(os.path.join(json_dir, file))
        print(png.shape)
        save_path = save_dir+'/'+os.path.splitext(file)[0]+".png"
        cv2.imwrite(save_path, png)
        print(save_path)

(2)data_loader.py
   class ToTensor(object)和class ToTensorLab(object)这两个类中,有对label进行归一化操作,去除该操作,因为计算loss的时候,多类别换成交叉熵损失函数,它本身包含了softmax操作
在这里插入图片描述
(3)model/u2net.py
   修改模型输出,作者在class U2NETP(nn.Module)和class U2NET(nn.Module)这两个类用了sigmoid函数,需要修改为直接输出,原因同上

python"># return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
return d0, d1, d2, d3, d4, d5, d6

(4)u2net_train.py
   修改损失函数和模型输出通道,将损失函数由原来的BCELoss,修改为CrossEntropyLoss,并设置模型的输出通道和类别一致

python"># bce_loss = nn.BCELoss(size_average=True)  # 注释
ce_loss = nn.CrossEntropyLoss()  # 添加
python"># def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):  # 注释
#     loss0 = bce_loss(d0, labels_v)
#     loss1 = bce_loss(d1, labels_v)
#     loss2 = bce_loss(d2, labels_v)
#     loss3 = bce_loss(d3, labels_v)
#     loss4 = bce_loss(d4, labels_v)
#     loss5 = bce_loss(d5, labels_v)
#     loss6 = bce_loss(d6, labels_v)

#     loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
#     print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
#     loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(),
#     loss6.data.item()))

#     return loss0, loss

def muti_ce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):  # 添加
    loss0 = ce_loss(d0, labels_v)
    loss1 = ce_loss(d1, labels_v)
    loss2 = ce_loss(d2, labels_v)
    loss3 = ce_loss(d3, labels_v)
    loss4 = ce_loss(d4, labels_v)
    loss5 = ce_loss(d5, labels_v)
    loss6 = ce_loss(d6, labels_v)

    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
    loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(),
    loss6.data.item()))

    return loss0, loss
python"># ------- 3. define model --------
# define the net
n_class = 3
if (model_name == 'u2net'):
    net = U2NET(3, n_class)
elif (model_name == 'u2netp'):
    net = U2NETP(3, n_class)

4. 测试

   该例子中,存在三个类别,分别是背景、螺丝钉、位移线,对应模型三个通道的输出,但模型输出为概率值,如何获取到真实的类别,以及将类别用不同颜色表示出来?可以用下面这个脚本实现模型推理和输出结果图

python">import os
import cv2
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB

# normalize the predicted SOD probability map
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)

    dn = (d-mi)/(ma-mi)

    return dn

def save_output(image_name,pred,d_dir):

    predict = pred
    predict = predict.squeeze()
    predict_np = predict.cpu().data.numpy()

    im = Image.fromarray(predict_np*255).convert('RGB')
    img_name = image_name.split(os.sep)[-1]
    image = io.imread(image_name)
    imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)

    pb_np = np.array(imo)

    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1,len(bbb)):
        imidx = imidx + "." + bbb[i]

    imo.save(d_dir+imidx+'.png')

def main():

    # --------- 1. get image path and name ---------
    model_name='u2net'#u2netp

    num_class = 3

    image_dir = os.path.join(os.getcwd(), 'test_data', 'ls_test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results_ls' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, 'u2net_bce_itr_1000_train_1.046126_tar_0.124982.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                        lbl_name_list = [],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if(model_name=='u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3,num_class)
    elif(model_name=='u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3,num_class)

    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:",img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']

        image = cv2.imread(img_name_list[i_test])
        image_name = os.path.basename(img_name_list[i_test])

        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
        d1 = d1.squeeze(dim=0)    # torch.Size([1, 3, 320, 320]) -> torch.Size([3, 320, 320])
        
        d1 = F.softmax(d1, dim=0)   # [3, 320, 320] 
        # print(d1[0, :, :])

        predict_np = torch.argmax(d1, dim=0, keepdim=True)
        # print(predict_np.shape)  # [1, 320, 320],3个类别,对应3个通道,获取概率值最高的下标

        predict_np = predict_np.cpu().detach().numpy().squeeze()   # 转到cpu设备

        predict_np = cv2.resize(predict_np, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)  # resize和原图一样的大小
        
        r = predict_np.copy()
        b = predict_np.copy()
        g = predict_np.copy()

        cls = dict([(1, (0, 0, 255)),
                    (2, (255, 0, 255)),
                    (3, (0, 255, 0)),
                    (4, (255, 0, 0)),
                    (5, (255, 255, 0))])
        for c in cls:
            r[r == c] = cls[c][0]
            g[g == c] = cls[c][1]
            b[b == c] = cls[c][2]

        rgb = np.zeros((image.shape[0], image.shape[1], 3))
        # print('类别', np.unique(predict_np))
        rgb[:, :, 0] = r
        rgb[:, :, 1] = g
        rgb[:, :, 2] = b

        im = Image.fromarray(rgb.astype(np.uint8))
        im.save('./test_data/my_results_2/' + str(image_name)[:-4] + '.png')

        del d1,d2,d3,d4,d5,d6,d7

if __name__ == "__main__":
    main()

5. 训练测试效果

   经过少量数据的训练测试,证明U2Net可以用来做多类别语义分割
输入图片

输入测试图片

在这里插入图片描述

模型测试效果

撒花完结🌟🌟🌟


http://www.niftyadmin.cn/n/5004608.html

相关文章

怎么处理zk或redis脑裂

很极端场景会出现脑裂 什么是分布式的脑裂 怎么理解zk脑裂 就是ZK,与客户端可能因为网络原因,客户端A还在跑着后续程序,而zk与客户端之前的心跳断了,此zk就把这节点给删除了,这时另一个客户会加锁成功,就样…

python调用C语言库

1. 在linux下通过gcc生成so库 //请保存为 foo.c #include<stdio.h> #define uint8_t unsigned char #define uint16_t unsigned shorttypedef struct TagMyStruct {char name[10];uint8_t age;int score; } MyStruct,*MyStructPointer;MyStructPointer foo_get_data_…

【分享】docker引发的172.17.x.x网段无法访问

前言: 想搭建一个测试环境&#xff0c;折腾vmware虚拟机&#xff0c;发现公司的172.17网段怎么都访问不了。使用traceroute 发现&#xff0c;一直走172.17.0.1&#xff0c;无论是怎么更改配置&#xff0c;都是如此。 查阅资料发现&#xff0c;当 Docker 启动时&#xff0c;会自…

“酱香拿铁”销售额破亿,产品经理笑了

“一天542万杯&#xff0c;突破1亿元。”最近小伙伴们应该看到了&#xff0c;瑞幸咖啡和茅台联名的“酱香拿铁”大获成功。仅用一天时间&#xff0c;就把产品推上了热搜第一的位置。 图片来源于网络&#xff0c;侵删 虽然现在小编还没有机会尝到酱香拿铁&#xff0c;但在看了众…

Java 中 List 集合取补集

交集 Intersection 英 [ˌɪntəˈsekʃn] 并集 Union 英 [ˈjuːniən] 差集 difference of set 补集 complement set 英 [ˈkɒmplɪment] Java 中 List 集合取交集 Java 中 List 集合取并集 Java 中 List 集合取差集 Java 中 List 集合取补集 # 求两个集合交集的补集 List&l…

《Effective C++中文版,第三版》读书笔记7

条款41&#xff1a; 了解隐式接口和编译期多态 隐式接口&#xff1a; ​ 仅仅由一组有效表达式构成&#xff0c;表达式自身可能看起来很复杂&#xff0c;但它们要求的约束条件一般而言相当直接而明确。 显式接口&#xff1a; ​ 通常由函数的签名式&#xff08;也就是函数名…

​Vue + Element UI前端篇(二):Vue + Element 案例 ​

Vue Element UI 实现权限管理系统 前端篇&#xff08;二&#xff09;&#xff1a;Vue Element 案例 导入项目 打开 Visual Studio Code&#xff0c;File --> add Folder to Workspace&#xff0c;导入我们的项目。 安装 Element 安装依赖 Element 是国内饿了么公司提…

ChatGPT是否可以协助人们提高公共演讲和表达能力?

ChatGPT作为一种自然语言处理的AI技术&#xff0c;具有潜在的能力协助人们提高公共演讲和表达能力。公共演讲和表达是重要的沟通技能&#xff0c;对于职业和个人发展都具有关键性的作用。本文将探讨ChatGPT如何在这方面发挥作用&#xff0c;包括以下几个方面&#xff1a; 1. *…