Hands-On Diffusion Models: Image Generation from Scratch in PyTorch (Complete Code Included)

张开发
2026/5/13 1:41:44 · 15 min read

In generative AI, diffusion models are rapidly becoming one of the most closely watched techniques. Unlike the adversarial training of GANs or the variational inference of VAEs, diffusion models achieve breakthrough image-generation quality through a distinctive destroy-and-reconstruct mechanism. This article walks through a complete PyTorch implementation, covering the core stages of noise scheduling, training optimization, and sampling acceleration, and ends by generating MNIST handwritten digits and CIFAR-10 natural images.

## 1. Environment Setup and Data Loading

### 1.1 Basic Environment

Python 3.8 with PyTorch 1.12 is recommended. Make sure the following dependencies are installed:

```bash
pip install torch torchvision matplotlib numpy tqdm
```

GPU acceleration additionally requires a working CUDA setup. Check device availability with:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
```

### 1.2 Data Preprocessing

Using MNIST as the example, we need a specific transform pipeline:

```python
import torch
from torchvision import transforms
from torchvision.datasets import MNIST

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: (x * 2) - 1)  # normalize pixel values to [-1, 1]
])

dataset = MNIST("./data", train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)
```

Note: diffusion models are sensitive to the scale of their inputs; keeping values in [-1, 1] helps training stability.

## 2. Designing the Noise Schedule

### 2.1 Cosine Noise Schedule

Compared with a linear schedule, the cosine schedule produces better image quality:

```python
import torch

def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine noise schedule.

    timesteps: total number of diffusion steps
    s: small offset controlling the starting point
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clamp(betas, 0, 0.999)

timesteps = 1000
betas = cosine_beta_schedule(timesteps)
```

### 2.2 Precomputing Key Variables

Computing the intermediate quantities needed during training ahead of time speeds things up noticeably:

```python
import torch.nn.functional as F

alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# posterior variance of q(x_{t-1} | x_t, x_0), referenced by the sampling code in section 4.3
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
```
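The training and sampling code later in the article indexes these precomputed per-timestep tensors with a helper named `extract`, which the original text never defines. Here is a minimal sketch of the usual implementation: gather each sample's coefficient by its timestep, then reshape so it broadcasts over an image batch.

```python
import torch

def extract(a, t, x_shape):
    """Gather a[t] for each sample in the batch, reshaped to broadcast over x.

    a: 1-D tensor of per-timestep coefficients, shape (timesteps,)
    t: long tensor of timestep indices, shape (batch,)
    x_shape: shape of the image batch the result will be multiplied with
    """
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    # e.g. (batch,) -> (batch, 1, 1, 1) for a 4-D image batch
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
```

With `x` of shape `(B, C, H, W)`, `extract(sqrt_alphas_cumprod, t, x.shape)` returns a `(B, 1, 1, 1)` tensor that multiplies cleanly against the batch.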
## 3. Core Model Architecture

### 3.1 Time Embedding Layer

The timestep needs a dedicated encoding before it can be injected into the network:

```python
import math
import torch
import torch.nn as nn

class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        inv_freq = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000) / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, t):
        t = t.float()
        pos_enc = t[:, None] * self.inv_freq[None, :]
        return torch.cat([torch.sin(pos_enc), torch.cos(pos_enc)], dim=-1)
```

### 3.2 Residual Blocks and Attention

The UNet backbone combines residual blocks with self-attention layers:

```python
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.SiLU()
        )
        self.time_proj = nn.Linear(time_dim, out_channels)
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.SiLU()
        )
        self.res_conv = (nn.Conv2d(in_channels, out_channels, 1)
                         if in_channels != out_channels else nn.Identity())

    def forward(self, x, t):
        h = self.conv1(x)
        h = h + self.time_proj(t)[:, :, None, None]  # inject the time embedding
        h = self.conv2(h)
        return h + self.res_conv(x)
```
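The article defines these building blocks but never shows them assembled. As a hedged illustration, here is one minimal way to wire `TimeEmbedding` and `ResidualBlock` into a small UNet-style noise predictor. The class name `TinyUNet`, the channel widths, and the single down/up level are my own choices rather than the article's; a real model would stack more levels and add the self-attention layers mentioned above.

```python
import math
import torch
import torch.nn as nn

class TimeEmbedding(nn.Module):  # as in section 3.1
    def __init__(self, dim):
        super().__init__()
        inv_freq = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000) / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, t):
        pos_enc = t.float()[:, None] * self.inv_freq[None, :]
        return torch.cat([torch.sin(pos_enc), torch.cos(pos_enc)], dim=-1)

class ResidualBlock(nn.Module):  # as in section 3.2
    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels), nn.SiLU())
        self.time_proj = nn.Linear(time_dim, out_channels)
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels), nn.SiLU())
        self.res_conv = (nn.Conv2d(in_channels, out_channels, 1)
                         if in_channels != out_channels else nn.Identity())

    def forward(self, x, t):
        h = self.conv1(x)
        h = h + self.time_proj(t)[:, :, None, None]
        h = self.conv2(h)
        return h + self.res_conv(x)

class TinyUNet(nn.Module):
    """One down/up level: just enough to smoke-test the plumbing."""
    def __init__(self, channels=1, base=32, time_dim=64):
        super().__init__()
        self.time_mlp = nn.Sequential(
            TimeEmbedding(time_dim), nn.Linear(time_dim, time_dim), nn.SiLU())
        self.down = ResidualBlock(channels, base, time_dim)
        self.pool = nn.MaxPool2d(2)
        self.mid = ResidualBlock(base, base * 2, time_dim)
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.up = ResidualBlock(base * 2 + base, base, time_dim)  # concat skip
        self.out = nn.Conv2d(base, channels, 1)

    def forward(self, x, t):
        temb = self.time_mlp(t)
        h = self.down(x, temb)             # (B, base, H, W)
        m = self.mid(self.pool(h), temb)   # (B, 2*base, H/2, W/2)
        u = self.upsample(m)               # (B, 2*base, H, W)
        return self.out(self.up(torch.cat([u, h], dim=1), temb))
```

The skip connection concatenates the pre-pooling features back in before the final block, which is the essential UNet ingredient for recovering spatial detail.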
## 4. The Full Training and Sampling Pipeline

### 4.1 The Diffusion Process

The key code for the forward noising process:

```python
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    sqrt_alpha_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alpha_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return sqrt_alpha_cumprod_t * x_start + sqrt_one_minus_alpha_cumprod_t * noise
```

### 4.2 Loss Function

A simplified noise-prediction loss:

```python
import torch.nn.functional as F

def p_losses(denoise_model, x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)
    return F.l1_loss(noise, predicted_noise)
```

### 4.3 Accelerating the Sampling Process

The per-step reverse update below is the standard DDPM-style ancestral step (despite the function's name, true DDIM acceleration comes from additionally skipping timesteps):

```python
@torch.no_grad()
def p_sample_ddim(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(torch.sqrt(1.0 / alphas), t, x.shape)

    # predict the noise component
    pred_noise = model(x, t)

    # compute the posterior mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * pred_noise / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise
```
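The update above still visits every timestep. The speed-up in DDIM comes from making the update deterministic (η = 0) and jumping along a strided subsequence of timesteps. Here is a hedged sketch of that idea, assuming the `alphas_cumprod` tensor from section 2.2 and a model with the `model(x, t)` signature used throughout this article; the function name `ddim_sample_loop` and the stride choice are mine, not the article's.

```python
import torch

@torch.no_grad()
def ddim_sample_loop(model, alphas_cumprod, shape, num_steps=50, device="cpu"):
    """Deterministic DDIM sampling (eta = 0) over a strided timestep subsequence."""
    timesteps = alphas_cumprod.shape[0]
    # strided subsequence from T-1 down to 0, e.g. 50 of 1000 steps
    seq = torch.linspace(0, timesteps - 1, num_steps).long().flip(0).tolist()
    x = torch.randn(shape, device=device)
    for i, t in enumerate(seq):
        ac_t = alphas_cumprod[t]
        # alpha_bar at the next (earlier) kept timestep; 1.0 after the final step
        ac_prev = alphas_cumprod[seq[i + 1]] if i + 1 < len(seq) else torch.tensor(1.0)
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)                                   # predicted noise
        x0_pred = (x - (1 - ac_t).sqrt() * eps) / ac_t.sqrt()     # implied clean image
        x = ac_prev.sqrt() * x0_pred + (1 - ac_prev).sqrt() * eps  # eta = 0 update
    return x
```

Because each step first reconstructs an estimate of `x_0` and then re-noises it to the next kept timestep, the loop stays consistent even with large jumps, which is what lets 50 steps stand in for 1000.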
## 5. Practical Optimization Tips

### 5.1 Mixed-Precision Training

Mixed precision sharply reduces memory use and speeds up training:

```python
scaler = torch.cuda.amp.GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    x = batch[0].to(device)
    t = torch.randint(0, timesteps, (x.shape[0],), device=device)
    with torch.cuda.amp.autocast():
        loss = p_losses(model, x, t)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

### 5.2 Visualizing the Results

Building an animation of the generation process:

```python
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from torchvision.utils import make_grid

@torch.no_grad()
def sample_animation(model, image_size, batch_size=16, channels=1):
    # start from pure noise
    img = torch.randn((batch_size, channels, image_size, image_size), device=device)
    imgs = []

    for i in reversed(range(0, timesteps)):
        t = torch.full((batch_size,), i, device=device, dtype=torch.long)
        img = p_sample_ddim(model, img, t, i)
        if i % 50 == 0:
            imgs.append(img.cpu().numpy())

    # build the animation
    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(111)
    ax.set_axis_off()

    def update(i):
        ax.clear()
        ax.set_axis_off()
        grid = make_grid(torch.Tensor(imgs[i]), nrow=4, normalize=True)
        ax.imshow(grid.permute(1, 2, 0))
        ax.set_title(f"Step {i * 50}")

    anim = FuncAnimation(fig, update, frames=len(imgs), interval=200)
    return anim
```

In real projects I have found that tuning the noise-schedule parameters has a pronounced effect on generation quality. On CIFAR-10, setting the cosine schedule's offset `s` to 0.01 while raising the total number of steps to 2000 noticeably improves the fine detail of generated images. When resources are limited, a progressive training strategy works well: train a base model on low-resolution images first, then raise the resolution step by step.
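The effect of that schedule tweak can be inspected directly. Below is a small diagnostic sketch, reusing `cosine_beta_schedule` from section 2.1, that compares the default schedule with the CIFAR-10 setting described above; the printed numbers characterize the schedules themselves, not generation quality.

```python
import torch

def cosine_beta_schedule(timesteps, s=0.008):  # as in section 2.1
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clamp(betas, 0, 0.999)

default = cosine_beta_schedule(1000, s=0.008)
tuned = cosine_beta_schedule(2000, s=0.01)  # the CIFAR-10 setting from the text

for name, betas in [("s=0.008, T=1000", default), ("s=0.01, T=2000", tuned)]:
    ac = torch.cumprod(1 - betas, dim=0)
    print(f"{name}: beta_0={betas[0]:.2e}, beta_max={betas.max():.3f}, "
          f"final alpha_bar={ac[-1]:.2e}")
```

At a fixed step count, a larger `s` lifts the earliest betas (so the very first steps are not vanishingly gentle), while doubling the step count makes each individual step smaller; the two changes together reshape where along the trajectory the signal is destroyed.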
