图像分割实战-系列教程6:unet医学细胞分割实战4(医学数据集、图像分割、语义分割、unet网络、代码逐行解读)

🍁🍁🍁图像分割实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

7、单个epoch的训练函数解析

def train(config, train_loader, model, criterion, optimizer):
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}
    model.train()
    pbar = tqdm(total=len(train_loader))
    for input, target, _ in train_loader:
        input = input.cuda()
        target = target.cuda()
        if config['deep_supervision']:
            outputs = model(input)
            loss = 0
            for output in outputs:
                loss += criterion(output, target)
            loss /= len(outputs)
            iou = iou_score(outputs[-1], target)
        else:
            output = model(input)
            loss = criterion(output, target)
            iou = iou_score(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_meters['loss'].update(loss.item(), input.size(0))
        avg_meters['iou'].update(iou, input.size(0))

        postfix = OrderedDict([ ('loss', avg_meters['loss'].avg), ('iou', avg_meters['iou'].avg)])
        pbar.set_postfix(postfix)
        pbar.update(1)
    pbar.close()

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg)])
  1. 定义训练函数,传入参数:配置信息、训练数据Dataloader、模型、损失函数、优化器
  2. 创建一个字典记录loss和iou,其中AverageMeter类的代码为:
 class AverageMeter(object):
    def __init__(self):
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
  1. 模型进入训练模式
  2. 创建进度条,按照总批次来显示
  3. 遍历Dataloader取出训练数据和标签
  4. 训练数据进入GPU
  5. 训练标签进入GPU
  6. 是否设置了每个位置加入监督,如果是,则
  7. 从模型中得到输出
  8. 当前loss置0
  9. 遍历所有的输出
  10. 求出所有输出的损失,并且累加到loss中
  11. 求出平均loss
  12. 根据模型最后一个输出和标签使用iou_score函数计算iou,iou_score函数代码为:
 def iou_score(output, target):
    smooth = 1e-5
    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()
    output_ = output > 0.5
    target_ = target > 0.5
    intersection = (output_ & target_).sum()
    union = (output_ | target_).sum()
    return (intersection + smooth) / (union + smooth)
  1. 如果不是设置了每个位置加入监督
  2. 从模型中得到输出
  3. 损失函数计算损失
  4. 根据模型所有输出和标签使用iou_score函数计算iou
  5. 梯度清零
  6. 向传播计算得到每个参数的梯度值
  7. 通过梯度下降执行一步参数更新
  8. 更新平均损失
  9. 更新平均iou
  10. 构建 postfix 字典展示进度条,从avg_meters 中相应的 AverageMeter 对象获取的当前平均损失和 IoU 值
  11. 更新进度条
  12. 关闭进度条
  13. 返回一个包含平均损失和 IoU 值的有序字典

8、单个epoch的验证函数

def validate(config, val_loader, model, criterion):
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}
    model.eval()
    with torch.no_grad():
        pbar = tqdm(total=len(val_loader))
        for input, target, _ in val_loader:
            input = input.cuda()
            target = target.cuda()
            if config['deep_supervision']:
                outputs = model(input)
                loss = 0
                for output in outputs:
                    loss += criterion(output, target)
                loss /= len(outputs)
                iou = iou_score(outputs[-1], target)
            else:
                output = model(input)
                loss = criterion(output, target)
                iou = iou_score(output, target)
            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['iou'].update(iou, input.size(0))
            postfix = OrderedDict([ ('loss', avg_meters['loss'].avg), ('iou', avg_meters['iou'].avg), ])
            pbar.set_postfix(postfix)
            pbar.update(1)
        pbar.close()
    return OrderedDict([('loss', avg_meters['loss'].avg), ('iou', avg_meters['iou'].avg)])

验证函数大部分内容与训练函数一致,只不过一个模型训练模式,一个是模型推理模式。此外验证函数中没有反向传播,梯度清零、梯度计算、参数更新等。


http://www.niftyadmin.cn/n/5298324.html

相关文章

【Windows编程】期末复习题2

系列文章目录 期末复习题1 文章目录 系列文章目录解释下列名词的含义?设备环境(描述表)保存了哪些信息?有什么作用?模态对话框与非模态对话框有什么区别?在程序设计中,经常要用到线程&#xff…

Flink学习-处理函数

简介 处理函数是Flink底层的函数,工作中通常用来做一些更复杂的业务处理,处理函数分好几种,主要包括基本处理函数,keyed处理函数,window处理函数。 Flink提供了8种不同处理函数: ProcessFunction&#x…

动画墙纸:将视频、网页、游戏、模拟器变成windows墙纸——Lively Wallpaper

文章目录 前言下载github地址:网盘 关于VideoWebpagesYoutube和流媒体ShadersGIFs游戏和应用程序& more:Performance:多监视器支持:完结 前言 Lively Wallpaper是一款开源的视频壁纸桌面软件,类似 Wallpaper Engine,兼容 Wal…

vue3-13

token可以是后端api的访问依据,一般绝大多数时候,前端要访问后端的api,后端都要求前端请求需要携带一个有效的token,这个token用于用户的身份校验,通过了校验,后端才会向前端返回数据,进行相应的操作,如果没…

类的加载顺序问题-demo展示

面试的的时候经常会被问到包含静态代码块、实例代码块和构造器等代码结构的加载顺序问题,下面借用一个面试题,回顾一下类的代码加载顺序。 public class AooTest {public static void main(String[] args) {AooTest.f1();}static AooTest test1 new Ao…

图片预览 element-plus 带页码

vue3、element-plus项目中&#xff0c;点击预览图片&#xff0c;并显示页码效果如图 安装 | Element Plus <div class"image__preview"><el-imagestyle"width: 100px; height: 100px":src"imgListArr[0]":zoom-rate"1.2":max…

SpringCloud-高级篇(九)

&#xff08;1&#xff09;Seata高可用 我们学习了Seata的各种用法了&#xff0c;Seata的服务是单节点部署的&#xff0c;这个服务如果挂了&#xff0c;整个事务都没有办法完了&#xff0c;下面我们学习Seata的高可用的知识。 实现高可用&#xff0c;还是比较简单&#xff0c;…

07-2-接口文档管理工具-swagger注解使用__ev

swagger参考demo package com.example.swagger2.controller;import com.example.swagger2.exception.SwaggerException; import com.example.swagger2.model.User; import io.swagger.annotations.*; import org.springframework.web.bind.annotation.*;import java.util.Has…