Stop Memorizing the U-Net Architecture! Hand-Build a Working Cell Segmentation Model in PyTorch (Full Code Included)

张开发
2026/6/11 7:10:43, 15 min read


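In deep learning, U-Net has become a benchmark model for medical image segmentation thanks to its encoder-decoder structure and skip connections. Many beginners, however, fall into one of two extremes: they either pore over the architecture diagram in the paper yet cannot write code that runs, or they call an off-the-shelf library without understanding how it works inside. This article builds a complete U-Net from scratch in PyTorch and exercises it on the ISBI cell segmentation dataset, covering the last mile from theory to practice.

We follow a progressive route: module-by-module implementation → full assembly → training and optimization, with reusable code at every step. Rather than only explaining the network structure, the focus is on the engineering details that matter in practice, such as:

- how to handle the special data formats of medical images
- how to implement skip connections concretely
- the gradient pitfalls of Dice Loss
- data augmentation strategies for small datasets

1. Environment Setup and Data Loading

1.1 Basic environment configuration

Python 3.8 and PyTorch 1.12 are recommended. The main dependencies are:

```
pip install torch torchvision
pip install opencv-python
pip install scikit-image
pip install tqdm
```

The code in this article relies on the following imports:

```python
import copy
import random
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import (Compose, ToPILImage, Resize, ToTensor,
                                    Normalize, InterpolationMode)
from skimage import io, transform
from skimage.exposure import adjust_gamma
import numpy as np
import matplotlib.pyplot as plt

# Used by the training and visualization code below
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

1.2 ISBI dataset handling

The ISBI cell segmentation dataset contains 30 training images and 30 test images, each 512x512. We need a custom Dataset class:

```python
class CellDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None, mask_transform=None):
        self.img_dir = Path(img_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        # Masks get their own pipeline so they are never intensity-normalized
        self.mask_transform = mask_transform
        self.images = sorted(self.img_dir.glob("*.tif"))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        # Masks are expected to be named <image stem>_mask.tif
        mask_path = self.mask_dir / img_path.name.replace(".tif", "_mask.tif")
        image = io.imread(str(img_path))
        mask = io.imread(str(mask_path))
        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        return image, mask
```

Note: medical images usually need dedicated preprocessing. `ToTensor` converts the image to a Tensor scaled to [0, 1], and `Normalize(mean=[0.5], std=[0.5])` then maps it to [-1, 1]. The mask must stay in [0, 1] for the loss in section 4.1, so it is resized with nearest-neighbor interpolation and not normalized:

```python
image_transform = Compose([
    ToPILImage(),
    Resize((256, 256)),  # lower the resolution to speed up training
    ToTensor(),
    Normalize(mean=[0.5], std=[0.5])
])

mask_transform = Compose([
    ToPILImage(),
    # Nearest-neighbor keeps the mask binary after resizing
    Resize((256, 256), interpolation=InterpolationMode.NEAREST),
    ToTensor()
])
```

2. U-Net Core Module Implementation

2.1 Basic convolution block

Every stage of U-Net consists of two 3x3 convolutions, each followed by ReLU. We wrap this in a reusable DoubleConv module. This version also adds BatchNorm and uses `padding=1` so that the spatial size is preserved:

```python
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
```

2.2 Downsampling module

The encoder downsamples with max pooling:

```python
class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
```

2.3 Upsampling module

The decoder upsamples with a transposed convolution and concatenates the result with the matching encoder feature map. This concatenation is the skip connection:

```python
class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Halve the channels while doubling the spatial size
        self.up = nn.ConvTranspose2d(
            in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # x1: decoder feature map, x2: encoder feature map from the skip connection
        x1 = self.up(x1)
        # Pad x1 so its size matches x2 (needed when the input size is odd)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
```

3. Assembling the Complete U-Net

3.1 Network architecture

Combine the modules into the full U-Net:

```python
class UNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=1):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        # Return probabilities; the loss in section 4.1 expects values in [0, 1]
        return torch.sigmoid(logits)
```

3.2 Parameter initialization

Use He initialization to improve training stability:

```python
def init_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

model = UNet()
model.apply(init_weights)
```

4. Training and Evaluation Strategy

4.1 Combined loss function

Combine Dice Loss and BCE Loss to deal with class imbalance. Because the model's `forward` already applies a sigmoid, the loss uses `F.binary_cross_entropy` on probabilities:

```python
class DiceBCELoss(nn.Module):
    def __init__(self, smooth=1.):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        # Flatten predictions and targets to 1-D
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2. * intersection + self.smooth) / (
            inputs.sum() + targets.sum() + self.smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction="mean")

        return BCE + dice_loss
```

4.2 Training loop

The full training procedure includes a validation phase and keeps the weights with the lowest validation loss:

```python
def train_model(model, criterion, optimizer, dataloaders, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float("inf")

    for epoch in range(num_epochs):
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                # Only track gradients during the training phase
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)

            # Keep a copy of the best weights seen on the validation set
            if phase == "val" and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_wts)
    return model
```

4.3 Evaluation metrics

Beyond the usual IoU, the function below reports precision, recall, and the Dice coefficient, which are commonly used for medical images:

```python
def calculate_metrics(pred, target, threshold=0.5):
    # Binarize the predicted probabilities
    pred = (pred > threshold).float()
    target = target.float()

    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    fn = ((1 - pred) * target).sum()

    # 1e-7 avoids division by zero on empty masks
    precision = tp / (tp + fp + 1e-7)
    recall = tp / (tp + fn + 1e-7)
    dice = (2 * tp) / (2 * tp + fp + fn + 1e-7)

    return precision, recall, dice
```

5. Practical Tips and Optimization Strategies

5.1 Data augmentation for small datasets

Medical image datasets are small, so we use augmentations suited to them. The transform below takes an `(image, mask)` pair. Note that `elastic_transform` is not defined in this article; it stands for a helper you supply yourself, and it has to apply the same displacement field to the image and the mask so that the two stay aligned.

```python
class MedicalTransform:
    def __call__(self, sample):
        image, mask = sample

        # Random elastic deformation
        # NOTE: elastic_transform is a user-supplied helper (not defined here);
        # it must deform image and mask with the same displacement field.
        if random.random() > 0.5:
            alpha = random.randint(100, 200)
            sigma = random.randint(8, 12)
            image = elastic_transform(image, alpha=alpha, sigma=sigma)
            mask = elastic_transform(mask, alpha=alpha, sigma=sigma)

        # Random gamma (intensity) change, applied to the image only
        if random.random() > 0.5:
            gamma = random.uniform(0.8, 1.2)
            image = adjust_gamma(image, gamma=gamma)

        return image, mask
```

5.2 Dynamic learning rate adjustment

Use cosine annealing to balance convergence speed and accuracy. Call `scheduler.step()` once per epoch:

```python
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_epochs, eta_min=1e-5)
```

5.3 Inference and visualization

After training, visualize the segmentation results:

```python
def plot_results(model, dataloader, num_images=3):
    model.eval()
    with torch.no_grad():
        for i, (images, masks) in enumerate(dataloader):
            if i >= num_images:
                break

            outputs = model(images.to(device))
            preds = (outputs > 0.5).float()

            # Show the first sample of each batch; [0, 0] selects its single channel
            plt.figure(figsize=(12, 4))
            plt.subplot(1, 3, 1)
            plt.imshow(images[0, 0].cpu(), cmap="gray")
            plt.title("Input")

            plt.subplot(1, 3, 2)
            plt.imshow(masks[0, 0].cpu(), cmap="gray")
            plt.title("Ground Truth")

            plt.subplot(1, 3, 3)
            plt.imshow(preds[0, 0].cpu(), cmap="gray")
            plt.title("Prediction")
            plt.show()
```

In real projects I have found that the model often segments cell edges poorly. Adding an edge-detection loss as an auxiliary task can noticeably improve prediction accuracy in boundary regions. Another practical trick is to use sub-pixel convolution instead of transposed convolution in the final upsampling stage, which reduces checkerboard artifacts.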
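The sub-pixel convolution tip can be sketched as follows. This is a minimal illustration rather than the author's implementation: `SubPixelUp` is a hypothetical name, and the 3x3 kernel and scale factor of 2 are assumptions. A regular convolution first produces `out_channels * scale**2` channels, and `nn.PixelShuffle` then rearranges them into `out_channels` channels at twice the resolution.

```python
import torch
import torch.nn as nn


class SubPixelUp(nn.Module):
    """Hypothetical 2x upsampling layer: Conv2d followed by PixelShuffle."""

    def __init__(self, in_channels, out_channels, scale=2):
        super().__init__()
        # The conv expands to out_channels * scale^2 channels at the same size
        self.conv = nn.Conv2d(in_channels, out_channels * scale ** 2,
                              kernel_size=3, padding=1)
        # PixelShuffle moves those extra channels into the spatial dimensions
        self.shuffle = nn.PixelShuffle(scale)

    def forward(self, x):
        return self.shuffle(self.conv(x))


# Same shape contract as nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
x = torch.randn(1, 128, 64, 64)
up = SubPixelUp(128, 64)
y = up(x)
print(y.shape)  # torch.Size([1, 64, 128, 128])
```

Because the output shape matches that of the transposed convolution, a layer like this could in principle stand in for `self.up` inside the last `Up` module; whether it actually reduces artifacts on your data is something to verify experimentally.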
