图像分割实战-系列教程8:unet医学细胞分割实战6(医学数据集、图像分割、语义分割、unet网络、代码逐行解读)

🍁🍁🍁图像分割实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

unet医学细胞分割实战1
unet医学细胞分割实战2
unet医学细胞分割实战3
unet医学细胞分割实战4
unet医学细胞分割实战5
unet医学细胞分割实战6

10、val.py解读

在结束训练后,已经保存了模型,现在可以对模型进行验证,运行过程需要指定参数:

"""
需要指定参数:--name dsb2018_96_NestedUNet_woDS
"""

10.1 基本参数与主要工具包

import argparse
import os
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
from albumentations.augmentations import transforms
from albumentations.core.composition import Compose
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import archs
from dataset import Dataset
from metrics import iou_score
from utils import AverageMeter
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default=None, help='model name')
    args = parser.parse_args()
    return args

这里在前面的train.py已经有过解析

10.2 main函数解析

def main():
    args = parse_args()

    with open('models/%s/config.yml' % args.name, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    print('-'*20)
    for key in config.keys():
        print('%s: %s' % (key, str(config[key])))
    print('-'*20)

    cudnn.benchmark = True

    # create model
    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])

    model = model.cuda()

    # Data loading code
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

    model.load_state_dict(torch.load('models/%s/model.pth' %
                                     config['name']))
    model.eval()

    val_transform = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    avg_meter = AverageMeter()

    for c in range(config['num_classes']):
        os.makedirs(os.path.join('outputs', config['name'], str(c)), exist_ok=True)
    with torch.no_grad():
        for input, target, meta in tqdm(val_loader, total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()

            # compute output
            if config['deep_supervision']:
                output = model(input)[-1]
            else:
                output = model(input)

            iou = iou_score(output, target)
            avg_meter.update(iou, input.size(0))

            output = torch.sigmoid(output).cpu().numpy()

            for i in range(len(output)):
                for c in range(config['num_classes']):
                    cv2.imwrite(os.path.join('outputs', config['name'], str(c), meta['img_id'][i] + '.jpg'),
                                (output[i, c] * 255).astype('uint8'))

    print('IoU: %.4f' % avg_meter.avg)
    
    plot_examples(input, target, model,num_examples=3)
    torch.cuda.empty_cache()

10.3 验证结果展示

def plot_examples(datax, datay, model,num_examples=6):
    fig, ax = plt.subplots(nrows=num_examples, ncols=3, figsize=(18,4*num_examples))
    m = datax.shape[0]
    for row_num in range(num_examples):
        image_indx = np.random.randint(m)
        image_arr = model(datax[image_indx:image_indx+1]).squeeze(0).detach().cpu().numpy()
        ax[row_num][0].imshow(np.transpose(datax[image_indx].cpu().numpy(), (1,2,0))[:,:,0])
        ax[row_num][0].set_title("Orignal Image")
        ax[row_num][1].imshow(np.squeeze((image_arr > 0.40)[0,:,:].astype(int)))
        ax[row_num][1].set_title("Segmented Image localization")
        ax[row_num][2].imshow(np.transpose(datay[image_indx].cpu().numpy(), (1,2,0))[:,:,0])
        ax[row_num][2].set_title("Target image")
    plt.show()

最终输出结果显示:
在这里插入图片描述

unet医学细胞分割实战1
unet医学细胞分割实战2
unet医学细胞分割实战3
unet医学细胞分割实战4
unet医学细胞分割实战5
unet医学细胞分割实战6


http://www.niftyadmin.cn/n/5298849.html

相关文章

Three.js基础入门介绍——Three.js学习四【模型导入】

模型导入 通过Three.js的材质和几何体,我们可以很方便的创建基础3D模型,但涉及到复杂模型时,一般是由专业建模工具 生成模型 文件再导入的方式将模型引入到我们的3D场景中进行使用 Three.js加载器 Three.js提供多种加载器以支持市面上多种…

idea的pom.xml文件灰色删除线解决办法

以上是点击了移除module后就变成这样 如果再次对着已移除的module右键会发现有个delete,点击这个是真删了,要谨慎备份哦 解决方案:恢复误操作remove module的解决方法 idea最右边,有个Maven控件,找到要恢复的module&a…

【华为机试】2023年真题B卷(python)-猴子爬山

一、题目 题目描述: 一天一只顽猴想去从山脚爬到山顶,途中经过一个有个N个台阶的阶梯,但是这猴子有一个习惯: 每一次只能跳1步或跳3步,试问猴子通过这个阶梯有多少种不同的跳跃方式? 二、输入输出 输入描述…

Matlab技巧[绘画逻辑分析仪产生的数据]

绘画逻辑分析仪产生的数据 逻分上抓到了ADC数字信号,一共是10Bit,12MHZ的波形: 这里用并口协议已经解析出数据: 导出csv表格数据(这个数据为补码,所以要做数据转换): 现在要把这个数据绘制成波形,用Python和表格直接绘制速度太慢了,转了一圈发现MATLAB很好用,操作方法如下:…

C语言编写Windows程序:组合启用/禁用Telnet客户端,并Telnet指定ip和端口

本文程序是将启用/禁用Telnet客户端的命令进行组合&#xff0c;单个命令的解析可参考文章&#xff1a; 启用/禁用Windows功能中的Telnet客户端的命令_()命令将阻止使用telnintel-CSDN博客 源代码如下&#xff1a; #include <stdio.h> #include <stdlib.h> #include…

C++日期类的实现

前言&#xff1a;在类和对象比较熟悉的情况下&#xff0c;我们我们就可以开始制作日期表了&#xff0c;实现日期类所包含的知识点有构造函数&#xff0c;析构函数&#xff0c;函数重载&#xff0c;拷贝构造函数&#xff0c;运算符重载&#xff0c;const成员函数 1.日期类的加减…

DrGraph原理示教 - OpenCV 4 功能 - 阈值

普通阈值 OpenCV中的阈值用于相对于提供的阈值分配像素值。在阈值处理中&#xff0c;将每个像素值与阈值进行比较&#xff0c;如果像素值小于阈值则设置为0&#xff0c;否则设置为最大值&#xff08;一般为255&#xff09;。 在OpenCV中&#xff0c;有多种阈值类型可供选择&am…

MyBatis之配置文件和映射文件

系列文章目录 提示&#xff1a;这里可以添加系列文章的所有文章的目录&#xff0c;目录需要自己手动添加 MyBatis之配置文件和映射文件 提示&#xff1a;写完文章后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 系列文章目录前言一、配置文…