U-Net模型进行训练钢材表面缺陷语义分割数据集 通过钢材缺陷分割数据集的权重模型,推理识别钢材分割

张开发
2026/4/15 13:16:30 15 分钟阅读

分享文章

U-Net模型进行训练钢材表面缺陷语义分割数据集 通过钢材缺陷分割数据集的权重模型,推理识别钢材分割
U-Net模型进行训练钢材表面缺陷语义分割数据集 通过钢材缺陷分割数据集的权重模型推理识别钢材分割文章目录环境搭建1. 安装CUDA驱动2. 安装Anaconda3. 创建Python虚拟环境4. 安装依赖项数据集准备使用U-Net训练模型训练代码推理代码以下文字及代码仅供参考学习使用。钢材表面缺陷语义分割数据集4432张数据jpg和mask掩码png有颜色映射关系另外有转换成coco格式json和yolo格式txt三种缺陷类型像素标签0为背景 1为夹杂物In 2为补丁Pa 3为划痕Sc共4432张数据jpg和mask掩码png有颜色映射关系另外有转换成coco格式json和yolo格式txt含三种缺陷类型像素标签0为背景 1为夹杂物In 2为补丁Pa 3为划痕Scmask标签颜色映射为了使用U-Net模型对钢材表面缺陷进行语义分割我们需要从环境搭建开始到数据集准备、模型训练和推理。以下是详细的步骤指南。仅供参考学习使用labelme查看环境搭建1. 安装CUDA驱动确保您的系统已经安装了与GPU兼容的CUDA驱动版本。可以使用以下命令检查nvidia-smi2. 安装Anaconda访问 Anaconda官网 下载并安装适合您操作系统的版本。3. 创建Python虚拟环境打开终端或Anaconda Prompt然后输入以下命令来创建并激活新的Python环境conda create--nameunet_envpython3.9conda activate unet_env4. 安装依赖项在激活的环境中运行以下命令以安装必要的库pipinstalltorch torchvision torchaudio pipinstallopencv-python pipinstallmatplotlib pipinstallscikit-image pipinstallalbumentations pipinstalltqdm pipinstalltimm pipinstallsegmentation-models-pytorch数据集准备假设同学你的数据集按照如下结构组织steel_defect_dataset/ ├── images/ │ ├── train/ │ ├── val/ │ └── test/ ├── masks/ │ ├── train/ │ ├── val/ │ └── test/ └── data.yamldata.yaml文件内容示例请根据实际情况调整路径train_images:./steel_defect_dataset/images/traintrain_masks:./steel_defect_dataset/masks/trainval_images:./steel_defect_dataset/images/valval_masks:./steel_defect_dataset/masks/valtest_images:./steel_defect_dataset/images/testtest_masks:./steel_defect_dataset/masks/testnc:3names:[In,Pa,Sc]使用U-Net训练模型使用segmentation_models.pytorch库中的U-Net模型进行训练。的Python脚本示例用于加载U-Net模型并使用提供的数据集进行训练。仅供参考学习使用。训练代码首先编写一个数据加载器函数用于加载图像和掩码并应用必要的预处理。importosfromtorch.utils.dataimportDataset,DataLoaderfromPILimportImageimportnumpyasnpfromtorchvisionimporttransformsclassSteelDefectDataset(Dataset):def__init__(self,img_dir,mask_dir,transformNone):self.img_dirimg_dir self.mask_dirmask_dir self.transformtransform self.imagesos.listdir(img_dir)def__len__(self):returnlen(self.images)def__getitem__(self,idx):img_pathos.path.join(self.img_dir,self.images[idx])mask_pathos.path.join(self.mask_dir,self.images[idx].replace(.jpg,.png))imagenp.array(Image.open(img_path).convert(RGB))masknp.array(Image.open(mask_path).convert(L),dtypenp.float32)mask[mask255.0]1.0# 背景为0其他类别为1,2,3ifself.transformisnotNone:augmentationsself.transform(imageimage,maskmask)imageaugmentations[image]maskaugmentations[mask]returnimage,mask# 数据增强importalbumentationsasAfromalbumentations.pytorchimportToTensorV2 transformA.Compose([A.Resize(height256,width256),A.Normalize(mean[0.0,0.0,0.0],std[1.0,1.0,1.0],max_pixel_value255.0,),ToTensorV2(),],)train_dsSteelDefectDataset(img_dirpath/to/train/images,mask_dirpath/to/train/masks,transformtransform,)val_dsSteelDefectDataset(img_dirpath/to/val/images,mask_dirpath/to/val/masks,transformtransform,)train_loaderDataLoader(train_ds,batch_size16,shuffleTrue)val_loaderDataLoader(val_ds,batch_size16,shuffleFalse)接下来是训练部分importtorchimporttorch.nnasnnfromsegmentation_models_pytorchimportUnetfromtqdmimporttqdm# 初始化模型modelUnet(encoder_nameresnet34,classes3,activationNone)# 损失函数和优化器loss_fnnn.CrossEntropyLoss()optimizertorch.optim.Adam(model.parameters(),lr1e-4)# 设备配置devicetorch.device(cudaiftorch.cuda.is_available()elsecpu)model.to(device)# 训练循环deftrain_model(model,train_loader,val_loader,loss_fn,optimizer,num_epochs10):forepochinrange(num_epochs):model.train()looptqdm(train_loader)forbatch_idx,(data,targets)inenumerate(loop):datadata.to(devicedevice)targetstargets.long().to(devicedevice)# 前向传播predictionsmodel(data)lossloss_fn(predictions,targets)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 更新进度条loop.set_postfix(lossloss.item())# 验证阶段model.eval()withtorch.no_grad():num_correct0num_pixels0dice_score0fordata,targetsinval_loader:datadata.to(device)targetstargets.to(device).unsqueeze(1)predictionstorch.softmax(model(data),dim1)predstorch.argmax(predictions,dim1).float()num_correct(predstargets).sum()num_pixelstorch.numel(preds)dice_score(2*(preds*targets).sum())/((predstargets).sum()1e-8)print(fGot{num_correct}/{num_pixels}with acc{num_correct/num_pixels*100:.2f})print(fDice score:{dice_score/len(val_loader)})# 开始训练train_model(model,train_loader,val_loader,loss_fn,optimizer,num_epochs10)推理代码训练完成后您可以使用训练好的模型对新图片进行预测。以下是一个简单的例子importcv2fromtorchvisionimporttransforms# 加载训练好的模型modelUnet(encoder_nameresnet34,classes3,activationNone)model.load_state_dict(torch.load(path/to/best_model.pth))model.eval()# 图像预处理preprocesstransforms.Compose([transforms.ToPILImage(),transforms.Resize((256,256)),transforms.ToTensor(),transforms.Normalize(mean[0.0,0.0,0.0],std[1.0,1.0,1.0]),])# 对单张图片进行预测image_pathpath/to/new/image.jpgimgcv2.imread(image_path)img_tensorpreprocess(img).unsqueeze(0).to(device)withtorch.no_grad():outputmodel(img_tensor)predictiontorch.argmax(output.squeeze(),dim0).cpu().numpy()# 显示结果deflabel_to_color_image(label):colormapnp.array([[0,0,0],[255,0,0],[0,255,0],[0,0,255]])returncolormap[label]color_predictionlabel_to_color_image(prediction)cv2.imshow(Prediction,color_prediction)cv2.waitKey(0)cv2.destroyAllWindows()

更多文章